# Owner(s): ["module: linear algebra"]

import torch
import numpy as np

import unittest
import itertools
import warnings
import math
from math import inf, nan, isnan
import re
import random
from random import randrange
from itertools import product
from functools import reduce, partial

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
     setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
     onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
     skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA,
     onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm,
     dtypesIfMPS, largeTensorTest)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types, all_types_and_complex_and, floating_and_complex_types, integral_types,
    floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
    _get_torch_cuda_version, CDNA2OrLater
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
import operator

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

if TEST_SCIPY:
    import scipy

def blaslt_supported_device():
    if torch.cuda.is_available():
        if torch.version.hip:
            for arch in ['gfx90a', 'gfx94']:
                if arch in torch.cuda.get_device_properties(0).gcnArchName:
                    return True
        else:
            return True
    return False

def set_tunableop_defaults():
    if not torch.cuda.is_available():
        # TunableOp not supported on CPU at this time.
        return

    # disable TunableOp and restore to default values
    ordinal = torch.cuda.current_device()
    filename = f"tunableop_results{ordinal}.csv"
    torch.cuda.tunable.enable(False)
    torch.cuda.tunable.tuning_enable(True)
    torch.cuda.tunable.set_filename(filename)  # reset back to default filename for next unit test
    torch.cuda.tunable.set_max_tuning_duration(30)
    torch.cuda.tunable.set_max_tuning_iterations(100)


class TestLinalg(TestCase):
    def setUp(self):
        super(self.__class__, self).setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super(self.__class__, self).tearDown()

    exact_dtype = True

    @dtypes(torch.float, torch.cfloat)
    @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06})
    @tf32_on_and_off(5e-3)
    @bf32_on_and_off(5e-3)
    def test_inner(self, device, dtype):
        def check(a_sizes_, b_sizes_):
            for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)):
                a = torch.randn(a_sizes, dtype=dtype, device=device)
                b = torch.randn(b_sizes, dtype=dtype, device=device)
                res = torch.inner(a, b)
                ref = np.inner(a.cpu().numpy(), b.cpu().numpy())
                self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref)))
                out = torch.zeros_like(res)
                torch.inner(a, b, out=out)
                self.assertEqual(res, out)

        check([], [])                       # scalar x scalar
        check([], [0])                      # scalar x empty
        check([], [3])                      # scalar x 1D
        check([], [2, 3, 4])                # scalar x 3D

        check([0], [0])                     # empty x empty
        check([0], [2, 0])                  # empty x 2D

        check([2], [2])                     # 1D x 1D
        check([2], [3, 1, 2])               # 1D x 3D
        check([2], [3, 0, 2])               # 1D x 3D empty

        check([1, 2], [3, 2])               # 2D x 2D
        check([1, 2], [3, 4, 2])            # 2D x 3D
        check([2, 1, 3, 2], [1, 3, 2, 2])   # 4D x 4D

        # Test error message
        with self.assertRaisesRegex(RuntimeError,
                                    r"inner\(\) the last dimension must match on both "
                                    r"input tensors but got shapes \[2, 3\] and \[2, 2\]"):
            torch.randn(2, 3, device=device, dtype=dtype).inner(torch.randn(2, 2, device=device, dtype=dtype))
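
    # For the non-scalar shape pairs above, torch.inner contracts the last
    # dimension of both operands (equivalently torch.tensordot(a, b, dims=([-1], [-1]))),
    # which is what np.inner computes; a scalar operand reduces the call to a
    # plain elementwise multiplication.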

    # Tests torch.outer, and its alias, torch.ger, vs. NumPy
    @precisionOverride({torch.bfloat16: 1e-1})
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_outer(self, device, dtype):
        def run_test_case(a, b):
            if dtype == torch.bfloat16:
                a_np = a.to(torch.double).cpu().numpy()
                b_np = b.to(torch.double).cpu().numpy()
                exact_dtype = False
            else:
                a_np = a.cpu().numpy()
                b_np = b.cpu().numpy()
                exact_dtype = True
            expected = np.outer(a_np, b_np)

            self.assertEqual(torch.outer(a, b), expected, exact_dtype=False)
            self.assertEqual(torch.Tensor.outer(a, b), expected, exact_dtype=False)

            self.assertEqual(torch.ger(a, b), expected, exact_dtype=False)
            self.assertEqual(torch.Tensor.ger(a, b), expected, exact_dtype=False)

            # test out variant
            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
            torch.outer(a, b, out=out)
            self.assertEqual(out, expected, exact_dtype=False)

            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
            torch.ger(a, b, out=out)
            self.assertEqual(out, expected, exact_dtype=False)

        a = torch.randn(50).to(device=device, dtype=dtype)
        b = torch.randn(50).to(device=device, dtype=dtype)
        run_test_case(a, b)

        # test 0 strided tensor
        zero_strided = torch.randn(1).to(device=device, dtype=dtype).expand(50)
        run_test_case(zero_strided, b)
        run_test_case(a, zero_strided)

    def test_matrix_rank_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.matrix_rank(a)

    def test_solve_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        b = make_tensor(5, 1, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.solve(b, a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            b.solve(a)

    def test_eig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.eig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.eig()

    def test_symeig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.symeig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.symeig()

    def test_lstsq_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.lstsq(a, a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.lstsq(a)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipIfTorchDynamo("flaky, needs investigation")
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq(self, device, dtype):
        from torch.testing._internal.common_utils import random_well_conditioned_matrix
        if self.device_type == 'cpu':
            drivers = ('gels', 'gelsy', 'gelsd', 'gelss', None)
        else:
            drivers = ('gels', None)
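        # Driver notes: 'gels' (plain QR) assumes a full-rank input, while
        # 'gelsy'/'gelsd'/'gelss' also handle rank-deficient problems but are
        # only available through LAPACK on CPU, which is why the CUDA path is
        # limited to 'gels'; passing None lets torch.linalg.lstsq pick its
        # default driver.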

        def check_solution_correctness(a, b, sol):
            sol2 = a.pinverse() @ b
            self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)

        def check_correctness_ref(a, b, res, ref, driver="default"):
            def apply_if_not_empty(t, f):
                if t.numel():
                    return f(t)
                else:
                    return t

            def select_if_not_empty(t, i):
                selected = apply_if_not_empty(t, lambda x: x.select(0, i))
                return selected

            m = a.size(-2)
            n = a.size(-1)
            nrhs = b.size(-1)
            batch_size = int(np.prod(a.shape[:-2]))
            if batch_size == 0:
                batch_size = 1
            a_3d = a.view(batch_size, m, n)
            b_3d = b.view(batch_size, m, nrhs)

            solution_3d = res.solution.view(batch_size, n, nrhs)
            residuals_2d = apply_if_not_empty(res.residuals, lambda t: t.view(-1, nrhs))
            rank_1d = apply_if_not_empty(res.rank, lambda t: t.view(-1))
            singular_values_2d = res.singular_values.view(batch_size, res.singular_values.shape[-1])

            if a.numel() > 0:
                for i in range(batch_size):
                    sol, residuals, rank, singular_values = ref(
                        a_3d.select(0, i).numpy(),
                        b_3d.select(0, i).numpy()
                    )
                    # Singular values are None when lapack_driver='gelsy' in SciPy
                    if singular_values is None:
                        singular_values = []
                    self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
                    self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
                    self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)

                    # SciPy and NumPy operate only on non-batched input and
                    # return an empty array with shape (0,) if rank(a) != n
                    # in PyTorch the batched inputs are supported and
                    # matrices in the batched input can have different ranks
                    # we compute residuals only if all matrices have rank == n
                    # see https://github.com/pytorch/pytorch/issues/56483
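                    # (SciPy and NumPy report the residuals as the squared
                    # 2-norm of each column of b - a @ x, and only for
                    # overdetermined (m > n), full-rank problems.)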
                    if m > n:
                        if torch.all(rank_1d == n):
                            self.assertEqual(
                                residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5, exact_dtype=False
                            )
                        else:
                            self.assertTrue(residuals_2d.numel() == 0)

            else:
                self.assertEqual(res.solution.shape, (*a.shape[:-2], n, nrhs))
                self.assertEqual(res.rank.shape, a.shape[:-2])

                # residuals are not always computed (and have non-zero shape)
                if m > n and driver != "gelsy":
                    self.assertEqual(res.residuals.shape, (*a.shape[:-2], 0))
                else:
                    self.assertEqual(res.residuals.shape, (0, ))

                # singular_values are not always computed (and have non-zero shape)
                if driver == "default" or driver == "gelsd" or driver == "gelss":
                    self.assertEqual(res.singular_values.shape, (*a.shape[:-2], min(m, n)))
                else:
                    self.assertEqual(res.singular_values.shape, (0, ))

        def check_correctness_scipy(a, b, res, driver, cond):
            # SciPy provides 3 driver options: gelsd, gelss, gelsy
            if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
                import scipy.linalg

                def scipy_ref(a, b):
                    return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
                check_correctness_ref(a, b, res, scipy_ref, driver=driver)

        def check_correctness_numpy(a, b, res, driver, rcond):
            # NumPy uses only gelsd routine
            if driver == 'gelsd':

                def numpy_ref(a, b):
                    return np.linalg.lstsq(a, b, rcond=rcond)
                check_correctness_ref(a, b, res, numpy_ref)

        ms = [2 ** i for i in range(5)]
        m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
        # cases m < n are only supported on CPU and for cuSOLVER path on CUDA
        m_l_n_sizes = [(m // 2, m) for m in ms]
        include_m_l_n_case = (has_cusolver() or device == 'cpu')
        matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if include_m_l_n_case else [])
        batches = [(), (2,), (2, 2), (2, 2, 2)]
        # we generate matrices with singular values sampled from a normal distribution;
        # using `cond=1.0` (the mean) cuts roughly half of all the singular values,
        # and we then compare whether torch.linalg.lstsq agrees with SciPy and NumPy.
        # if rcond is True then set value for it based on the used algorithm
        # rcond == -1 or any other negative value forces LAPACK to use machine precision tolerance
        rconds = (None, True, -1)

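        # For the SVD-based drivers (gelsd/gelss) rcond follows the LAPACK
        # convention: singular values at or below rcond * largest_singular_value
        # are treated as zero when forming the minimum-norm solution; the
        # QR-based gelsy instead uses it to estimate the effective rank.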
        for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
            # keep the rcond value if it is None or -1, set the driver specific value if it is True
            if rcond and rcond != -1:
                if driver in ('gelss', 'gelsd'):
                    # SVD based algorithm; set to zero roughly half of all the singular values
                    rcond = 1.0
                else:
                    # driver == 'gelsy'
                    # QR based algorithm; setting the value too high might lead to non-unique solutions and flaky tests
                    # so we skip this case
                    continue

            # specifying rcond value has no effect for gels driver so no need to run the tests again
            if driver == 'gels' and rcond is not None:
                continue

            shape = batch + matrix_size
            a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
            b = torch.rand(*shape, dtype=dtype, device=device)

            m = a.size(-2)
            n = a.size(-1)
            res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
            sol = res.solution

            # Only checks gelsd, gelss, gelsy drivers
            check_correctness_scipy(a, b, res, driver, rcond)

            # Only checks gelsd driver
            check_correctness_numpy(a, b, res, driver, rcond)

            # gels driver is not checked by comparing to NumPy or SciPy implementation
            # because NumPy and SciPy do not implement this driver
            if driver == 'gels' and rcond is None:
                check_solution_correctness(a, b, sol)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq_batch_broadcasting(self, device, dtype):
        from torch.testing._internal.common_utils import random_well_conditioned_matrix

        def check_correctness(a, b):
            sol = torch.linalg.lstsq(a, b).solution
            sol2 = a.pinverse() @ b
            self.assertEqual(sol, sol2, rtol=1e-5, atol=1e-5)

        ms = [2 ** i for i in range(5)]
        batches = [(), (0,), (2,), (2, 2), (2, 2, 2)]
        # the case when a single matrix is batch-broadcasted over the rhs
        for m, batch in itertools.product(ms, batches):
            a = random_well_conditioned_matrix(m, m, dtype=dtype, device=device).view(*([1] * len(batch)), m, m)
            b = torch.rand(*(batch + (m, m)), dtype=dtype, device=device)
            check_correctness(a, b)

        # cases with broadcastable shapes
        for m in ms:
            a = random_well_conditioned_matrix(1, 3, 1, 3, m, m, dtype=dtype, device=device)
            b = torch.rand(3, 1, 3, 1, m, m // 2, dtype=dtype, device=device)
            check_correctness(a, b)

            # rhs are vectors, not matrices in this test
            b = torch.rand(3, 1, 3, 1, m, dtype=dtype, device=device)
            # unsqueeze for b because `check_correctness` checks against
            # a.pinverse() @ b, which requires b to be a matrix
            check_correctness(a, b.unsqueeze(-1))

            a = random_well_conditioned_matrix(3, 1, 3, 1, m, m, dtype=dtype, device=device)
            b = torch.rand(1, 3, 1, 3, m, m // 2, dtype=dtype, device=device)
            check_correctness(a, b)

            # rhs are vectors, not matrices in this test
            b = torch.rand(1, 3, 1, 3, m, dtype=dtype, device=device)
            check_correctness(a, b.unsqueeze(-1))

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_lstsq_input_checks(self, device, dtype):
        # check empty inputs
        # empty batches
        a = torch.rand(0, 0, 3, 3, dtype=dtype, device=device)
        b = torch.rand(0, 0, 3, 2, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(0, 0, 3, 2, dtype=dtype, device=device)
        )
        # empty a and b
        a = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
        )
        # empty a and b
        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
        )
        # empty a but not b
        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
        b = torch.rand(2, 2, 3, 2, dtype=dtype, device=device)
        self.assertEqual(
            torch.linalg.lstsq(a, b)[0],
            torch.zeros(2, 2, 0, 2, dtype=dtype, device=device)
        )

        # empty a and b
        if torch.device(device).type == 'cpu':
            # only CPU since CUDA does not support underdetermined systems
            a = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
            b = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
            self.assertEqual(
                torch.linalg.lstsq(a, b)[0],
                torch.zeros(2, 2, 3, 3, dtype=dtype, device=device)
            )

        a = torch.rand(2, 3, dtype=dtype, device=device)
        b = torch.rand(3, dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, 'input must have at least 2 dimensions'):
            torch.linalg.lstsq(b, b)

        with self.assertRaisesRegex(RuntimeError, 'other must have at least 1 dimension'):
            torch.linalg.lstsq(a, torch.tensor(1, dtype=dtype, device=device))

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-1\)'):
            torch.linalg.lstsq(a, b)

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
            torch.linalg.lstsq(a, b.unsqueeze(-1))

        a = torch.randn(1, 1, 1, dtype=dtype, device=device)
        b = torch.randn(3, 1, dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
            torch.linalg.lstsq(a, b)

        def complement_device(device):
            if device == 'cpu' and torch.cuda.is_available():
                return 'cuda'
            else:
                return 'cpu'

        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
        b = torch.rand(2, 2, 2, dtype=dtype, device=complement_device(device))
        if a.device != b.device:
            with self.assertRaisesRegex(RuntimeError, 'be on the same device'):
                torch.linalg.lstsq(a, b)

        b = (torch.rand(2, 2, 2, dtype=dtype, device=device) * 100).long()
        with self.assertRaisesRegex(RuntimeError, 'the same dtype'):
            torch.linalg.lstsq(a, b)

        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
        b = torch.rand(2, 2, 2, dtype=dtype, device=device)

        if device != 'cpu':
            with self.assertRaisesRegex(RuntimeError, '`driver` other than `gels` is not supported on CUDA'):
                torch.linalg.lstsq(a, b, driver='fictitious_driver')
        # if on cpu
        else:
            with self.assertRaisesRegex(RuntimeError, r'parameter `driver` should be one of \(gels, gelsy, gelsd, gelss\)'):
                torch.linalg.lstsq(a, b, driver='fictitious_driver')

        # cuSOLVER path supports underdetermined systems
        version = torch.testing._internal.common_cuda._get_torch_cuda_version()
        cusolver_not_available = (version < (10, 1))

        if device != 'cpu' and cusolver_not_available:
            a = torch.rand(2, 3, dtype=dtype, device=device)
            b = torch.rand(2, 1, dtype=dtype, device=device)
            with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems'):
                torch.linalg.lstsq(a, b)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(shape, batch, contiguous):
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
            if A.numel() > 0 and not contiguous:
                A = A.mT
                self.assertFalse(A.is_contiguous())
            expected_L = np.linalg.cholesky(A.cpu().numpy())
            actual_L = torch.linalg.cholesky(A)

            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
            # Let's compare the norms of matrices instead
            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
                # axis is specified to calculate matrix norm for batched input
                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
                # Compare the norms with standard tolerances
                self.assertEqual(actual_norm, expected_norm)
                # and individual values with a higher tolerance
                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
            else:
                self.assertEqual(actual_L, expected_L)

        shapes = (0, 3, 5)
        batches = ((), (3, ), (2, 2))
        larger_input_case = [(100, (5, ), True)]
        for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case:
            run_test(shape, batch, contiguous)

        # check the out= variant
        A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device)
        out = torch.empty_like(A)
        ans = torch.linalg.cholesky(A, out=out)
        self.assertEqual(ans, out)
        expected = torch.linalg.cholesky(A)
        self.assertEqual(expected, out)

        # check the upper= variant
        expected = torch.linalg.cholesky(A).mH
        actual = torch.linalg.cholesky(A, upper=True)
        self.assertEqual(expected, actual)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky_errors_and_warnings(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        # cholesky requires the input to be a square matrix or batch of square matrices
        A = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
            torch.linalg.cholesky(A)
        A = torch.randn(2, 2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'):
            np.linalg.cholesky(A.cpu().numpy())

        # cholesky requires the input to be at least 2 dimensional tensor
        A = torch.randn(2, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError,
                                    r'1-dimensional array given\. Array must be at least two-dimensional'):
            np.linalg.cholesky(A.cpu().numpy())

        # if the input matrix is not positive definite, an error should be raised
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A[-1, -1] = 0  # Now A is not positive definite
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
            torch.linalg.cholesky(A)
        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
            np.linalg.cholesky(A.cpu().numpy())

        # if at least one matrix in the batch is singular, an error should be raised
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A = A.reshape((1, 3, 3))
        A = A.repeat(5, 1, 1)
        A[4, -1, -1] = 0  # Now A[4] is not positive definite
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
            torch.linalg.cholesky(A)

        # if out tensor with wrong shape is passed a warning is given
        A = random_hermitian_pd_matrix(3, dtype=dtype, device=device)
        out = torch.empty(2, 3, dtype=dtype, device=device)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.cholesky(A, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out = torch.empty(*A.shape, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.cholesky(A, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                torch.linalg.cholesky(A, out=out)

    # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py
    @slowTest
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_old_cholesky_batched_many_batches(self, device, dtype):
        from torch.testing._internal.common_utils import random_symmetric_pd_matrix

        def cholesky_test_helper(n, batchsize, device, upper):
            A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device)
            chol_fact = torch.cholesky(A, upper=upper)
            if upper:
                # Correctness check
                self.assertEqual(A, chol_fact.mT.matmul(chol_fact))
                # Upper triangular check
                self.assertEqual(chol_fact, chol_fact.triu())
            else:
                # Correctness check
                self.assertEqual(A, chol_fact.matmul(chol_fact.mT))
                # Lower triangular check
                self.assertEqual(chol_fact, chol_fact.tril())

        for upper, batchsize in itertools.product([True, False], [262144, 524288]):
            cholesky_test_helper(2, batchsize, device, upper)

    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_old_cholesky_batched(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def cholesky_test_helper(n, batch_dims, upper):
            A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
            cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
            cholesky_exp = cholesky_exp.reshape_as(A)
            self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))

        for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]):
            cholesky_test_helper(3, batchsize, upper)

    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @tf32_on_and_off(0.01)
    @bf32_on_and_off(0.01)
    def test_old_cholesky(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        A = random_hermitian_pd_matrix(10, dtype=dtype, device=device)

        # default Case
        C = torch.cholesky(A)
        B = torch.mm(C, C.t().conj())
        self.assertEqual(A, B, atol=1e-14, rtol=0)

        # test Upper Triangular
        U = torch.cholesky(A, True)
        B = torch.mm(U.t().conj(), U)
        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix')

        # test Lower Triangular
        L = torch.cholesky(A, False)
        B = torch.mm(L, L.t().conj())
        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix')

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_old_cholesky_empty(self, device, dtype):
        def run_test(upper):
            A = torch.empty(0, 0, dtype=dtype, device=device)
            chol = torch.cholesky(A, upper)
            chol_A = torch.matmul(chol, chol.t().conj())
            self.assertEqual(A, chol_A)
        for upper in [True, False]:
            run_test(upper)

    # Test for issue
    # https://github.com/pytorch/pytorch/issues/57032
    # torch.cholesky with upper=True for batched CUDA inputs was wrong
    # it was using the lower triangular part instead of the upper one
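    # The test below therefore zeroes out the lower triangle of the input and
    # checks that upper=True still reconstructs A as U.mH @ U, i.e. that only
    # the upper triangular part is read.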
    @onlyCUDA
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    def test_old_cholesky_batched_upper(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        batchsize = 2
        A = random_hermitian_pd_matrix(3, batchsize, dtype=dtype, device=device)
        A_triu = A.triu()  # fill the lower triangular part with zero

        U = torch.cholesky(A_triu, upper=True)

        reconstruct_A = U.mH @ U
        self.assertEqual(A, reconstruct_A)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky_ex(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(n, batch):
            A = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
            expected_L = np.linalg.cholesky(A.cpu().numpy())
            expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
            actual_L, actual_info = torch.linalg.cholesky_ex(A)

            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
            # Let's compare the norms of matrices instead
            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
                # axis is specified to calculate matrix norm for batched input
                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
                # Compare the norms with standard tolerances
                self.assertEqual(actual_norm, expected_norm)
                # and individual values with a higher tolerance
                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
            else:
                self.assertEqual(actual_L, expected_L)
            self.assertEqual(actual_info, expected_info)

        ns = (0, 3, 5)
Coastguard Worker batches = ((), (2, ), (2, 1)) 721*da0073e9SAndroid Build Coastguard Worker for n, batch in itertools.product(ns, batches): 722*da0073e9SAndroid Build Coastguard Worker run_test(n, batch) 723*da0073e9SAndroid Build Coastguard Worker 724*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagmaAndNoCusolver 725*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 726*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 727*da0073e9SAndroid Build Coastguard Worker def test_cholesky_ex_non_pd(self, device, dtype): 728*da0073e9SAndroid Build Coastguard Worker # if the input matrix is not positive definite, info with positive integer is returned 729*da0073e9SAndroid Build Coastguard Worker A = torch.eye(3, 3, dtype=dtype, device=device) 730*da0073e9SAndroid Build Coastguard Worker A[-1, -1] = 0 # Now A is singular 731*da0073e9SAndroid Build Coastguard Worker _, info = torch.linalg.cholesky_ex(A) 732*da0073e9SAndroid Build Coastguard Worker self.assertEqual(info, 3) 733*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'): 734*da0073e9SAndroid Build Coastguard Worker torch.linalg.cholesky_ex(A, check_errors=True) 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker # if at least one matrix in the batch is not positive definite, 737*da0073e9SAndroid Build Coastguard Worker # batched info with positive integer for the corresponding matrix is returned 738*da0073e9SAndroid Build Coastguard Worker A = torch.eye(3, 3, dtype=dtype, device=device) 739*da0073e9SAndroid Build Coastguard Worker A = A.reshape((1, 3, 3)) 740*da0073e9SAndroid Build Coastguard Worker A = A.repeat(5, 1, 1) 741*da0073e9SAndroid Build Coastguard Worker A[3, -2, -2] = 0 # Now A[3] is singular 742*da0073e9SAndroid Build Coastguard Worker _, info = torch.linalg.cholesky_ex(A) 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Worker expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device) 745*da0073e9SAndroid Build Coastguard Worker expected_info[3] = 2 746*da0073e9SAndroid Build Coastguard Worker self.assertEqual(info, expected_info) 747*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'): 748*da0073e9SAndroid Build Coastguard Worker torch.linalg.cholesky_ex(A, check_errors=True) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker def _test_addr_vs_numpy(self, device, dtype, beta=1, alpha=1): 751*da0073e9SAndroid Build Coastguard Worker def check(m, a, b, beta, alpha): 752*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 753*da0073e9SAndroid Build Coastguard Worker a_np = a.to(torch.double).cpu().numpy() 754*da0073e9SAndroid Build Coastguard Worker b_np = b.to(torch.double).cpu().numpy() 755*da0073e9SAndroid Build Coastguard Worker m_np = m.to(torch.double).cpu().numpy() 756*da0073e9SAndroid Build Coastguard Worker exact_dtype = False 757*da0073e9SAndroid Build Coastguard Worker else: 758*da0073e9SAndroid Build Coastguard Worker a_np = a.cpu().numpy() 759*da0073e9SAndroid Build Coastguard Worker b_np = b.cpu().numpy() 760*da0073e9SAndroid Build Coastguard Worker m_np = m.cpu().numpy() 761*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 762*da0073e9SAndroid Build Coastguard Worker if beta == 0: 763*da0073e9SAndroid Build Coastguard 
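                # With beta == 0 the input matrix is ignored entirely (BLAS convention),
                # so the reference drops the m term; the NaN/Inf check further below
                # relies on this.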
                expected = alpha * np.outer(a_np, b_np)
            else:
                expected = beta * m_np + alpha * np.outer(a_np, b_np)

            res = torch.addr(m, a, b, beta=beta, alpha=alpha)
            self.assertEqual(res, expected, exact_dtype=exact_dtype)

            # Test out variant
            out = torch.empty_like(res)
            torch.addr(m, a, b, beta=beta, alpha=alpha, out=out)
            self.assertEqual(out, expected, exact_dtype=exact_dtype)

        m = make_tensor((50, 50), device=device, dtype=dtype, low=-2, high=2)
        a = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
        b = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)

        check(m, a, b, beta, alpha)

        # test transpose
        m_transpose = torch.transpose(m, 0, 1)
        check(m_transpose, a, b, beta, alpha)

        # test 0 strided tensor
        zero_strided = make_tensor((1,), device=device, dtype=dtype, low=-2, high=2).expand(50)
        check(m, zero_strided, b, beta, alpha)

        # test scalar
        m_scalar = torch.tensor(1, device=device, dtype=dtype)
        check(m_scalar, a, b, beta, alpha)

        # test nans and infs are not propagated to the output when beta == 0
        float_and_complex_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
        if beta == 0 and dtype in float_and_complex_dtypes:
            m[0][10] = m[10][10] = m[20][20] = float('inf')
            m[1][10] = m[11][10] = m[21][20] = float('nan')
            check(m, a, b, 0, alpha)

    @dtypes(torch.bool)
    def test_addr_bool(self, device, dtype):
        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=False)
        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=True)
        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=False)
        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=True)

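    # Hypothetical sketch of the alpha/beta rules exercised by the next two tests
    # (illustrative only, not executed here): for integral dtypes both scalars must be
    # integers, and boolean scalars are only allowed when the result is boolean, e.g.
    #   m = torch.zeros(2, 2, dtype=torch.int64)
    #   v = torch.ones(2, dtype=torch.int64)
    #   torch.addr(m, v, v, beta=2, alpha=3)    # ok
    #   torch.addr(m, v, v, beta=2., alpha=3)   # RuntimeError: beta must not be a float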
    @dtypes(*integral_types())
    def test_addr_integral(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError,
                                    'argument beta must not be a floating point number.'):
            self._test_addr_vs_numpy(device, dtype, beta=2., alpha=1)
        with self.assertRaisesRegex(RuntimeError,
                                    'argument alpha must not be a floating point number.'):
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=1.)
        with self.assertRaisesRegex(RuntimeError,
                                    'Boolean beta only supported for Boolean results.'):
            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
        with self.assertRaisesRegex(RuntimeError,
                                    'Boolean alpha only supported for Boolean results.'):
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)

        # when beta is zero
        self._test_addr_vs_numpy(device, dtype, beta=0, alpha=2)
        # when beta is not zero
        self._test_addr_vs_numpy(device, dtype, beta=2, alpha=2)

    @precisionOverride({torch.bfloat16: 1e-1})
    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
    def test_addr_float_and_complex(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError,
                                    'Boolean beta only supported for Boolean results.'):
            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
        with self.assertRaisesRegex(RuntimeError,
                                    'Boolean alpha only supported for Boolean results.'):
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)

        # when beta is zero
        self._test_addr_vs_numpy(device, dtype, beta=0., alpha=2)
        # when beta is not zero
        self._test_addr_vs_numpy(device, dtype, beta=0.5, alpha=2)
        if dtype in complex_types():
            self._test_addr_vs_numpy(device, dtype, beta=(0 + 0.1j), alpha=(0.2 - 0.2j))

    @dtypes(*itertools.product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
                               all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
    def test_outer_type_promotion(self, device, dtypes):
        a = torch.randn(5).to(device=device, dtype=dtypes[0])
        b = torch.randn(5).to(device=device, dtype=dtypes[1])
        for op in (torch.outer, torch.Tensor.outer, torch.ger, torch.Tensor.ger):
            result = op(a, b)
            self.assertEqual(result.dtype, torch.result_type(a, b))

    # don't use the @dtypes decorator to avoid generating ~1700 tests per device
    def test_addr_type_promotion(self, device):
        for dtypes0, dtypes1, dtypes2 in product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), repeat=3):
            a = make_tensor((5,), device=device, dtype=dtypes0, low=-2, high=2)
            b = make_tensor((5,), device=device, dtype=dtypes1, low=-2, high=2)
            m = make_tensor((5, 5), device=device, dtype=dtypes2, low=-2, high=2)

            desired_dtype = torch.promote_types(torch.promote_types(dtypes0, dtypes1),
                                                dtypes2)
            for op in (torch.addr, torch.Tensor.addr):
                result = op(m, a, b)
                self.assertEqual(result.dtype, desired_dtype)

    # Tests migrated from test_torch.py
    # 1) test the shape of the result tensor when there is an empty input tensor
    # 2) test that a RuntimeError is raised when there is a scalar input tensor
    def test_outer_ger_addr_legacy_tests(self, device):
        for size in ((0, 0), (0, 5), (5, 0)):
            a = torch.rand(size[0], device=device)
            b = torch.rand(size[1], device=device)

            self.assertEqual(torch.outer(a, b).shape, size)
            self.assertEqual(torch.ger(a, b).shape, size)

            m = torch.empty(size, device=device)
            self.assertEqual(torch.addr(m, a, b).shape, size)

        m = torch.randn(5, 6, device=device)
        a = torch.randn(5, device=device)
        b = torch.tensor(6, device=device)
        self.assertRaises(RuntimeError, lambda: torch.outer(a, b))
        self.assertRaises(RuntimeError, lambda: torch.outer(b, a))
        self.assertRaises(RuntimeError, lambda: torch.ger(a, b))
        self.assertRaises(RuntimeError, lambda: torch.ger(b, a))
        self.assertRaises(RuntimeError, lambda: torch.addr(m, a, b))
        self.assertRaises(RuntimeError, lambda: torch.addr(m, b, a))

    # Tests torch.det and its alias, torch.linalg.det, vs. NumPy
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double, torch.cdouble)
    def test_det(self, device, dtype):
        tensors = (
            torch.randn((2, 2), device=device, dtype=dtype),
            torch.randn((129, 129), device=device, dtype=dtype),
            torch.randn((3, 52, 52), device=device, dtype=dtype),
            torch.randn((4, 2, 26, 26), device=device, dtype=dtype))

        ops = (torch.det, torch.Tensor.det,
               torch.linalg.det)
        for t in tensors:
            expected = np.linalg.det(t.cpu().numpy())
            for op in ops:
                actual = op(t)
                self.assertEqual(actual, expected)
                self.compare_with_numpy(op, np.linalg.det, t)

        # NOTE: det requires a 2D+ tensor
        t = torch.randn(1, device=device, dtype=dtype)
        with self.assertRaises(RuntimeError):
            op(t)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    def test_eigh(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_matrix

        def run_test(shape, batch, uplo):
            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
            self.assertEqual(actual_w, expected_w)
            # the sign of the eigenvectors is not unique, so absolute values are compared
            self.assertEqual(abs(actual_v), abs(expected_v))
            # additionally, we can multiply the eigenvectors by a phase factor e^{i\phi} and then compare the values
            # let's choose the convention that the first element of the eigenvectors from torch and numpy be the same
            # for real inputs, this phase factor is plus or minus one
            if matrix.numel() > 0:
                phase = torch.from_numpy(expected_v[..., 0, :]).to(device=device).div(actual_v[..., 0, :])
                actual_v_rotated = actual_v * phase.unsqueeze(-2).expand_as(actual_v)
                self.assertEqual(actual_v_rotated, expected_v)

            # check the out= variant
            out_w = torch.empty_like(actual_w)
            out_v = torch.empty_like(actual_v)
            ans_w, ans_v = torch.linalg.eigh(matrix, UPLO=uplo, out=(out_w, out_v))
            self.assertEqual(ans_w, out_w)
            self.assertEqual(ans_v, out_v)
            self.assertEqual(ans_w, actual_w)
            self.assertEqual(abs(ans_v), abs(actual_v))

        shapes = (0, 3, 5)
        batches = ((), (3, ), (2, 2))
        uplos = ["U", "L"]
        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
            run_test(shape, batch, uplo)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    def test_eigh_lower_uplo(self, device, dtype):
        def run_test(shape, batch, uplo):
            # check lower case uplo
            # use non-symmetric input to check whether the uplo argument is working as intended
            matrix = torch.randn(shape, shape, *batch, dtype=dtype, device=device)
            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
            self.assertEqual(actual_w, expected_w)
            self.assertEqual(abs(actual_v), abs(expected_v))

        uplos = ["u", "l"]
        for uplo in uplos:
            run_test(3, (2, 2), uplo)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_eigh_errors_and_warnings(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_matrix

        # eigh requires a square matrix
        t = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.linalg.eigh(t)

        # eigh requires the 'uplo' parameter to be 'U' or 'L'
        t = torch.randn(3, 3, device=device, dtype=dtype)
        for uplo in ["a", "wrong"]:
            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
                torch.linalg.eigh(t, UPLO=uplo)
            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
                np.linalg.eigh(t.cpu().numpy(), UPLO=uplo)

        # if a non-empty out tensor with the wrong shape is passed, a warning is given
        a = random_hermitian_matrix(3, dtype=dtype, device=device)
        real_dtype = a.real.dtype if dtype.is_complex else dtype
        out_w = torch.empty(7, 7, dtype=real_dtype, device=device)
        out_v = torch.empty(7, 7, dtype=dtype, device=device)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.eigh(a, out=(out_w, out_v))
            # Check warning occurs
            self.assertEqual(len(w), 2)
            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out_w = torch.empty(0, dtype=real_dtype, device=device)
        out_v = torch.empty(0, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.eigh(a, out=(out_w, out_v))

        out_w = torch.empty(0, dtype=torch.int, device=device)
        out_v = torch.empty(0, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.eigh(a, out=(out_w, out_v))

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out_w = torch.empty(0, device=wrong_device, dtype=dtype)
            out_v = torch.empty(0, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eigh(a, out=(out_w, out_v))
            out_w = torch.empty(0, device=device, dtype=dtype)
            out_v = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eigh(a, out=(out_w, out_v))

    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    @unittest.skipIf(_get_torch_cuda_version() < (12, 1), "Test is fixed on cuda 12.1 update 1.")
    def test_eigh_svd_illcondition_matrix_input_should_not_crash(self, device, dtype):
        # See https://github.com/pytorch/pytorch/issues/94772, https://github.com/pytorch/pytorch/issues/105359
        # This test crashes with `cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED` on cuda 11.8,
        # but passes on cuda 12.1 update 1 or later.
        a = torch.ones(512, 512, dtype=dtype, device=device)
        a[0, 0] = 1.0e-5
        a[-1, -1] = 1.0e5

        eigh_out = torch.linalg.eigh(a)
        svd_out = torch.linalg.svd(a)

        # Matrix input a is too ill-conditioned.
        # We'll just compare the first two singular values/eigenvalues. They are 1.0e5 and 511.0.
        # The absolute tolerance of 1.0 makes sense since ill-conditioned inputs are hard to converge
        # to exact values.
        self.assertEqual(eigh_out.eigenvalues.sort(descending=True).values[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
        self.assertEqual(svd_out.S[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
    def test_eigvalsh(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_matrix

        def run_test(shape, batch, uplo):
            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
            expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo)
            actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo)
            self.assertEqual(actual_w, expected_w)

            # check the out= variant
            out = torch.empty_like(actual_w)
            ans = torch.linalg.eigvalsh(matrix, UPLO=uplo, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, actual_w)

        shapes = (0, 3, 5)
        batches = ((), (3, ), (2, 2))
        uplos = ["U", "L"]
        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
            run_test(shape, batch, uplo)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_eigvalsh_errors_and_warnings(self, device, dtype):
        # eigvalsh requires a square matrix
        t = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.linalg.eigvalsh(t)

        # eigvalsh requires the 'uplo' parameter to be 'U' or 'L'
        t = torch.randn(3, 3, device=device, dtype=dtype)
        for uplo in ["a", "wrong"]:
            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
                torch.linalg.eigvalsh(t, UPLO=uplo)
            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
                np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo)

        # if a non-empty out tensor with the wrong shape is passed, a warning is given
        real_dtype = t.real.dtype if dtype.is_complex else dtype
        out = torch.empty_like(t).to(real_dtype)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.eigvalsh(t, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out = torch.empty(0, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.eigvalsh(t, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eigvalsh(t, out=out)

    @dtypes(*floating_and_complex_types())
    def test_kron(self, device, dtype):

        def run_test_case(a_shape, b_shape):
            a = torch.rand(a_shape, dtype=dtype, device=device)
            b = torch.rand(b_shape, dtype=dtype, device=device)

            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
            result = torch.kron(a, b)
            self.assertEqual(result, expected)

            # check the out= variant
            out = torch.empty_like(result)
            ans = torch.kron(a, b, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, result)

        shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
        for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
            run_test_case(a_shape, b_shape)

    @dtypes(*floating_and_complex_types())
    def test_kron_empty(self, device, dtype):

        def run_test_case(empty_shape):
            a = torch.eye(3, dtype=dtype, device=device)
            b = torch.empty(empty_shape, dtype=dtype, device=device)
            result = torch.kron(a, b)
            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
            self.assertEqual(result, expected)

            # NumPy doesn't work if the first argument is empty
            result = torch.kron(b, a)
            self.assertEqual(result.shape, expected.shape)

        empty_shapes = [(0,), (2, 0), (1, 0, 3)]
        for empty_shape in empty_shapes:
            run_test_case(empty_shape)

    @dtypes(*floating_and_complex_types())
    def test_kron_errors_and_warnings(self, device, dtype):
        # if a non-empty out tensor with the wrong shape is passed, a warning is given
        a = torch.eye(3, dtype=dtype, device=device)
        b = torch.ones((2, 2), dtype=dtype, device=device)
        out = torch.empty_like(a)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.kron(a, b, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should match
        out = torch.empty_like(a).to(torch.int)
        with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
            torch.kron(a, b, out=out)

    # This test confirms that torch.linalg.norm's dtype argument works
    # as expected, according to the function's documentation
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
    def test_norm_dtype(self, device, dtype):
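        # Quick illustration of the contract exercised below (an assumed summary, not
        # quoted from the docs): without a dtype argument the result uses the real
        # dtype of the input, and passing dtype=d must match casting the input first:
        #   torch.linalg.norm(torch.ones(3, dtype=torch.cfloat)).dtype  # torch.float32
        #   torch.linalg.norm(x, dtype=d) == torch.linalg.norm(x.to(d))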
        make_arg = partial(make_tensor, dtype=dtype, device=device)

        def run_test_case(input_size, ord, keepdim, to_dtype):
            msg = (
                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
                f'dtype={dtype}, to_dtype={to_dtype}')
            input = make_arg(input_size)
            result = torch.linalg.norm(input, ord, keepdim=keepdim)
            self.assertEqual(result.dtype, input.real.dtype, msg=msg)

            result_out = torch.empty((0), dtype=result.dtype, device=device)
            torch.linalg.norm(input, ord, keepdim=keepdim, out=result_out)
            self.assertEqual(result, result_out, msg=msg)

            result = torch.linalg.norm(input.to(to_dtype), ord, keepdim=keepdim)
            result_with_dtype = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype)
            self.assertEqual(result, result_with_dtype, msg=msg)

            result_out_with_dtype = torch.empty_like(result_with_dtype)
            torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_with_dtype)
            self.assertEqual(result_with_dtype, result_out_with_dtype, msg=msg)

        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]

        # With these orders we are computing the 10-th power and the 10-th root of numbers.
        # We avoid them for half-precision types as they make the tests above too badly conditioned.
        if dtype != torch.float16 and dtype != torch.bfloat16:
            ord_vector.extend([0.1, -0.1])
        ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None]
        S = 10

        if dtype == torch.cfloat:
            norm_dtypes = (torch.cfloat, torch.cdouble)
        elif dtype == torch.cdouble:
            norm_dtypes = (torch.cdouble,)
        elif dtype in (torch.float16, torch.bfloat16, torch.float):
            norm_dtypes = (torch.float, torch.double)
        elif dtype == torch.double:
            norm_dtypes = (torch.double,)
        else:
            raise RuntimeError("Unsupported dtype")

        for ord, keepdim, norm_dtype in product(ord_vector, (True, False), norm_dtypes):
            run_test_case((S,), ord, keepdim, norm_dtype)

        for ord, keepdim, norm_dtype in product(ord_matrix, (True, False), norm_dtypes):
            if ord in [2, -2, 'nuc']:
                # We need torch.svdvals
                if dtype == torch.float16 or dtype == torch.bfloat16:
                    continue

                # We need LAPACK or equivalent
                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
                        (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
                    continue
            run_test_case((S, S), ord, keepdim, norm_dtype)

    # This test confirms that torch.linalg.norm produces the correct result for bfloat16 and half inputs.
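    # The reference below is computed in float32 and cast back to the input dtype, so
    # the test passes only if the reduced-precision path accumulates accurately enough
    # to match that reference for a tensor of ones.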
    @dtypes(torch.bfloat16, torch.float16)
    def test_norm_bfloat16_and_half(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)

        def run_test_case(input_size, ord, keepdim):
            msg = (
                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
                f'dtype={dtype}')
            input = make_arg(input_size).fill_(1)
            result_ref = torch.linalg.norm(input.float(), ord, keepdim=keepdim).to(dtype=dtype)
            result = torch.linalg.norm(input, ord, keepdim=keepdim)
            self.assertEqual(result_ref, result, msg=msg)

        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
        for S, ord, keepdim in product((10, 2049), ord_vector, (True, False)):
            run_test_case((S,), ord, keepdim)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
    def test_vector_norm(self, device, dtype):
        if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]:
            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
        # have to use torch.randn(...).to(bfloat16) instead of
        # This test compares torch.linalg.vector_norm's output with
        # torch.linalg.norm given a flattened tensor
        ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
        input_sizes = [
            (1, ),
            (10, ),
            (4, 5),
            (3, 4, 5),
            (0, ),
            (0, 10),
            (0, 0),
            (10, 0, 10),
        ]

        def vector_norm_reference(input, ord, dim=None, keepdim=False, dtype=None):
            if dim is None:
                input_maybe_flat = input.flatten(0, -1)
            else:
                input_maybe_flat = input

            result = torch.linalg.norm(input_maybe_flat, ord, dim=dim, keepdim=keepdim, dtype=dtype)
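            # The input was flattened when dim is None, so the keepdim shape is lost in
            # the call above; restore the expected broadcastable shape by hand below.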
            if keepdim and dim is None:
                result = result.reshape([1] * input.dim())
            return result

        def run_test_case(input, ord, dim, keepdim, norm_dtype):
            if (input.numel() == 0 and
                    (ord < 0. or ord == inf) and
                    (dim is None or input.shape[dim] == 0)):
                # The operation does not have an identity.
                error_msg = "linalg.vector_norm cannot compute"
                with self.assertRaisesRegex(RuntimeError, error_msg):
                    torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim)
            else:
                msg = (f'input.size()={input.size()}, ord={ord}, dim={dim}, '
                       f'keepdim={keepdim}, dtype={dtype}, norm_dtype={norm_dtype}')
                result_dtype_reference = vector_norm_reference(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
                result_dtype = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
                if dtype.is_complex:
                    result_dtype_reference = result_dtype_reference.real
                self.assertEqual(result_dtype, result_dtype_reference, msg=msg)

                if norm_dtype is not None:
                    ref = torch.linalg.vector_norm(input.to(norm_dtype), ord, dim=dim, keepdim=keepdim)
                    actual = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
                    self.assertEqual(ref, actual, msg=msg)

        if dtype == torch.cfloat:
            norm_dtypes = (None, torch.cfloat, torch.cdouble)
        elif dtype == torch.cdouble:
            norm_dtypes = (None, torch.cdouble)
        elif dtype in (torch.float16, torch.bfloat16, torch.float):
            norm_dtypes = (None, torch.float, torch.double)
        elif dtype == torch.double:
            norm_dtypes = (None, torch.double)
        else:
            raise RuntimeError("Unsupported dtype")

        for amp in [False, True]:
            with torch.autocast(device_type=device, enabled=amp):
                for input_size, ord, keepdim, norm_dtype in product(input_sizes, ord_vector, [True, False], norm_dtypes):
                    input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
                    for dim in [None, random.randint(0, len(input_size) - 1)]:
                        run_test_case(
                            input,
                            ord,
                            dim,
                            keepdim,
                            norm_dtype)

    def test_vector_norm_dim_tuple_arg(self, device):
        test_cases = [
            # input size, dim, error, error message
            ((4, ), (0, ), None, None),
            ((4, ), (1, ), IndexError, r'Dimension out of range'),
            ((4, ), (-2, ), IndexError, r'Dimension out of range'),
            ((4, 3), (0, -1), None, None),
            ((4, 3), (0, 0), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
            ((4, 3), (0, -2), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
            ((4, 3), (0, 1.0), TypeError, r"argument 'dim' must be tuple of ints"),
            ((4, 3), (None, ), TypeError, r"argument 'dim' must be tuple of ints"),
        ]
        for input_size, dim_tuple, error, error_msg in test_cases:
            input = torch.randn(input_size, device=device)
            # vector_norm should accept a tuple or a list for the dim arg
            for dim in [dim_tuple, list(dim_tuple)]:
                if error is None:
                    torch.linalg.vector_norm(input, dim=dim)
                else:
                    with self.assertRaises(error):
                        torch.linalg.vector_norm(input, dim=dim)

    # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
    # their vector norm results match
    @dtypes(torch.float, torch.double)
    def test_norm_vector(self, device, dtype):
        def run_test_case(input, ord, dim, keepdim):
            result = torch.linalg.norm(input, ord, dim, keepdim)
            input_numpy = input.cpu().numpy()
            result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)

            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
            self.assertEqual(result, result_numpy, msg=msg)

            result_out = torch.empty_like(result)
            torch.linalg.norm(input, ord, dim, keepdim, out=result_out)
            self.assertEqual(result, result_out, msg=msg)

        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf]
        S = 10
        test_cases = [
            # input size, p settings, dim
            ((S, ), ord_vector, None),
            ((S, ), ord_vector, 0),
            ((S, S, S), ord_vector, 0),
            ((S, S, S), ord_vector, 1),
            ((S, S, S), ord_vector, 2),
            ((S, S, S), ord_vector, -1),
            ((S, S, S), ord_vector, -2),
        ]
        L = 1_000_000
        if dtype == torch.double:
            test_cases.append(((L, ), ord_vector, None))
        for keepdim in [True, False]:
            for input_size, ord_settings, dim in test_cases:
                input = torch.randn(*input_size, dtype=dtype, device=device)
                for ord in ord_settings:
                    run_test_case(input, ord, dim, keepdim)

    # This test compares torch.linalg.norm, torch.linalg.matrix_norm and numpy.linalg.norm to
    # ensure that their matrix norm results match.
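    # For reference (assumed shorthand, not quoted from the docs): 'fro' is the
    # elementwise L2 norm, 'nuc' is the sum of singular values, and ord=2/-2 are the
    # largest/smallest singular values, which is why the 2, -2 and 'nuc' cases below
    # are gated on an SVD backend (LAPACK/MAGMA/cuSOLVER) being available.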
    @skipMeta  # https://github.com/pytorch/pytorch/issues/54082
    @skipCUDAIfNoMagma
    @dtypes(torch.float, torch.double)
    @precisionOverride({torch.float32: 2e-4})
    def test_norm_matrix(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)

        def run_test_case(input, ord, dim, keepdim):
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
            input_numpy = input.cpu().numpy()
            result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)

            result = torch.linalg.norm(input, ord, dim, keepdim)
            self.assertEqual(result, result_numpy, msg=msg)
            if ord is not None and dim is not None:
                result = torch.linalg.matrix_norm(input, ord, dim, keepdim)
                self.assertEqual(result, result_numpy, msg=msg)

        ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro']
        S = 10
        test_cases = [
            # input size, dim
            ((S, S), None),
            ((S, S), (0, 1)),
            ((S, S), (1, 0)),
            ((S, S, S, S), (2, 0)),
            ((S, S, S, S), (-1, -2)),
            ((S, S, S, S), (-1, -3)),
            ((S, S, S, S), (-3, 2)),
        ]

        for (shape, dim), keepdim, ord in product(test_cases, [True, False], ord_matrix):
            if ord in [2, -2, 'nuc']:
                # We need torch.svdvals
                if dtype == torch.float16 or dtype == torch.bfloat16:
                    continue
                # We need LAPACK or equivalent
                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
                        (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
                    continue
            run_test_case(make_arg(shape), ord, dim, keepdim)

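    # The next test profiles the call to check that requesting dtype=torch.float from a
    # reduced-precision input is handled inside the norm kernel rather than through an
    # explicit cast (no aten::to event should show up in the trace).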
    @onlyCUDA
    @dtypes(torch.bfloat16, torch.float16)
    def test_norm_fused_type_promotion(self, device, dtype):
        x = torch.randn(10, device=device, dtype=dtype)

        def profile_and_check(fn, x, kwargs):
            with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
                fn(x, **kwargs, dtype=torch.float)
            # smoke check that profiler returned some events
            self.assertTrue("aten::linalg_vector_norm" in (e.name for e in p.events()))
            # test that there was no explicit copy
            self.assertFalse("aten::to" in (e.name for e in p.events()))

        for f, kwargs in zip((torch.linalg.vector_norm, torch.norm), ({}, {"p": 2})):
            profile_and_check(f, x, kwargs)

    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3})
    def test_cond(self, device, dtype):
        def run_test_case(input, p):
            result = torch.linalg.cond(input, p)
            result_numpy = np.linalg.cond(input.cpu().numpy(), p)
            self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision, exact_dtype=False)
            self.assertEqual(result.shape, result_numpy.shape)

            # test out= variant
            out = torch.empty_like(result)
            ans = torch.linalg.cond(input, p, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, result)

        norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
        input_sizes = [(32, 32), (2, 3, 3, 3)]
        for input_size in input_sizes:
            input = torch.randn(*input_size, dtype=dtype, device=device)
            for p in norm_types:
                run_test_case(input, p)

        # test empty batch sizes
        input_sizes = [(0, 3, 3), (0, 2, 5, 5)]
        for input_size in input_sizes:
            input = torch.randn(*input_size, dtype=dtype, device=device)
            for p in norm_types:
                run_test_case(input, p)

        # test non-square input
        input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)]
        for input_size in input_sizes:
            input = torch.randn(*input_size, dtype=dtype, device=device)
            for p in [2, -2, None]:
                run_test_case(input, p)

        # test for singular input
        a = torch.eye(3, dtype=dtype, device=device)
        a[-1, -1] = 0  # make 'a' singular
        for p in norm_types:
            try:
                run_test_case(a, p)
            except np.linalg.LinAlgError:
                # NumPy may fail to converge for some BLAS backends (although this is very rare)
                # See the discussion in https://github.com/pytorch/pytorch/issues/67675
                pass

        # test for 0x0 matrices. NumPy doesn't work for such input, we return 0
        input_sizes = [(0, 0), (2, 5, 0, 0)]
        for input_size in input_sizes:
            input = torch.randn(*input_size, dtype=dtype, device=device)
            for p in ['fro', 2]:
                expected_dtype = a.real.dtype if dtype.is_complex else dtype
                expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device)
                actual = torch.linalg.cond(input, p)
                self.assertEqual(actual, expected)

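    # Illustrative sketch (not exercised by the suite): for the default p (and
    # p=2) torch.linalg.cond reduces to the ratio of the largest to the smallest
    # singular value, which is why a singular matrix yields an infinite condition
    # number above. Helper name and dtype are assumptions made for this example.
    def _sketch_cond_via_svdvals(self, device='cpu'):
        A = make_tensor((5, 5), dtype=torch.float64, device=device)
        sigma = torch.linalg.svdvals(A)
        # cond(A) with the default norm equals sigma_max / sigma_min
        self.assertEqual(torch.linalg.cond(A), sigma.max() / sigma.min())
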
    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3})
    def test_cond_errors_and_warnings(self, device, dtype):
        norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]

        # cond expects the input to be at least 2-dimensional
        a = torch.ones(3, dtype=dtype, device=device)
        for p in norm_types:
            with self.assertRaisesRegex(RuntimeError, r'at least 2 dimensions'):
                torch.linalg.cond(a, p)

        # for some norm types cond expects the input to be square
        a = torch.ones(3, 2, dtype=dtype, device=device)
        norm_types = [1, -1, inf, -inf, 'fro', 'nuc']
        for p in norm_types:
            with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
                torch.linalg.cond(a, p)

        # if non-empty out tensor with wrong shape is passed a warning is given
        a = torch.ones((2, 2), dtype=dtype, device=device)
        for p in ['fro', 2]:
            real_dtype = a.real.dtype if dtype.is_complex else dtype
            out = torch.empty(a.shape, dtype=real_dtype, device=device)
            with warnings.catch_warnings(record=True) as w:
                # Trigger warning
                torch.linalg.cond(a, p, out=out)
                # Check warning occurs
                self.assertEqual(len(w), 1)
                self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out = torch.empty(0, dtype=torch.int, device=device)
        for p in ['fro', 2]:
            with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
                torch.linalg.cond(a, p, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, dtype=dtype, device=wrong_device)
            for p in ['fro', 2]:
                with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                    torch.linalg.cond(a, p, out=out)

        # for batched input if at least one matrix in the batch is not invertible,
        # we can't get the result for all other (possibly) invertible matrices in the batch without an explicit for loop.
        # this should change when at::inverse works with silent errors
        # NumPy works fine in this case because it's possible to silence the error and get the inverse matrix results
        # possibly filled with NANs
        batch_dim = 3
        a = torch.eye(3, 3, dtype=dtype, device=device)
        a = a.reshape((1, 3, 3))
        a = a.repeat(batch_dim, 1, 1)
        a[1, -1, -1] = 0  # now a[1] is singular
        for p in [1, -1, inf, -inf, 'fro', 'nuc']:
            result = torch.linalg.cond(a, p)
            self.assertEqual(result[1], float('inf'))

        # check invalid norm type
        a = torch.ones(3, 3, dtype=dtype, device=device)
        for p in ['wrong_norm', 5]:
            with self.assertRaisesRegex(RuntimeError, f"linalg.cond got an invalid norm type: {p}"):
                torch.linalg.cond(a, p)

    # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments
    # to ensure that they both throw errors
    @dtypes(torch.float, torch.double)
    def test_norm_errors(self, device, dtype):
        def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex):
            test_case_info = (
                f'test case input.size()={input.size()}, ord={ord}, dim={dim}, '
                f'keepdim={keepdim}, dtype={dtype}')

            with self.assertRaisesRegex(error_type, error_regex, msg=test_case_info):
                torch.linalg.norm(input, ord, dim, keepdim)

            input_numpy = input.cpu().numpy()

            msg = f'numpy does not raise error but pytorch does, for case "{test_case_info}"'
            with self.assertRaises(Exception, msg=msg):
                np.linalg.norm(input_numpy, ord, dim, keepdim)

        S = 10
        error_test_cases = [
            # input size, p settings, dim, error type, error regex
            ((S, ), ['fro', 'nuc'], None, RuntimeError, r'A must have at least 2 dimensions'),
            ((S, S), [3.5], None, RuntimeError, r'matrix_norm: Order 3.5 not supported'),
            ((S, S), [0], None, RuntimeError, r'matrix_norm: Order 0 not supported'),
            ((S, S), ['fail'], None, RuntimeError, r'matrix_norm: Order fail not supported'),
            ((S, S), ['fro', 'nuc'], 0, RuntimeError, r'matrix_norm: dim must be a 2-tuple'),
            ((S, S), ['fro', 'nuc', 2], (0, 0), RuntimeError, r'dims must be different'),
            ((S, S), ['fro', 'nuc', 2], (-1, 1), RuntimeError, r'dims must be different'),
            ((S, S), ['fro', 'nuc', 2], (0, 4), IndexError, r'Dimension out of range'),
            ((S, ), [0], (4, ), IndexError, r'Dimension out of range'),
            ((S, ), [None], (0, 0), RuntimeError, r'dim 0 appears multiple times'),
            ((S, S, S), [1], (0, 1, 2), RuntimeError, r"If dim is specified, it must be of length 1 or 2."),
            ((S, S, S), [1], None, RuntimeError, r"If dim is not specified but ord is, the input must be 1D or 2D"),
        ]
        for keepdim in [True, False]:
            for input_size, ord_settings, dim, error_type, error_regex in error_test_cases:
                input = torch.randn(*input_size, dtype=dtype, device=device)
                for ord in ord_settings:
                    run_error_test_case(input, ord, dim, keepdim, error_type, error_regex)

    # Test complex number inputs for linalg.norm
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.cfloat, torch.cdouble)
    @precisionOverride({torch.cfloat: 5e-4})
    def test_norm_complex(self, device, dtype):
        def gen_error_message(input_size, ord, keepdim, dim=None):
            return f"complex norm failed for input size {input_size}, ord={ord}, keepdim={keepdim}, dim={dim}"

        vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf]
        matrix_ords = [None, 'fro', 'nuc', 1, 2, inf, -1, -2, -inf]

        # Test supported ords
        for keepdim in [False, True]:
            # vector norm
            x = torch.randn(25, device=device, dtype=dtype)
            xn = x.cpu().numpy()
            for ord in vector_ords:
                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
                msg = gen_error_message(x.size(), ord, keepdim)
                self.assertEqual(res.shape, expected.shape, msg=msg)
                self.assertEqual(res, expected, msg=msg, exact_dtype=False)

                res_out = torch.tensor([], device=device, dtype=res.dtype)
                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
                self.assertEqual(res_out.shape, expected.shape, msg=msg)
                self.assertEqual(res_out, expected, msg=msg)

            # matrix norm
            x = torch.randn(25, 25, device=device, dtype=dtype)
            xn = x.cpu().numpy()
            for ord in matrix_ords:
                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
                msg = gen_error_message(x.size(), ord, keepdim)
                self.assertEqual(res.shape, expected.shape, msg=msg)
                self.assertEqual(res, expected, msg=msg, exact_dtype=False)

                res_out = torch.tensor([], device=device, dtype=res.dtype)
                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
                self.assertEqual(res_out.shape, expected.shape, msg=msg)
                self.assertEqual(res_out, expected, msg=msg)
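
    # Illustrative sketch (not exercised by the suite): complex vector norms are
    # real-valued because they are taken over elementwise magnitudes, e.g.
    # ||x||_1 == sum(|x_i|). Helper name and dtype choices are assumptions made
    # for this example only.
    def _sketch_complex_norm_is_real(self, device='cpu'):
        x = torch.randn(25, device=device, dtype=torch.cdouble)
        res = torch.linalg.vector_norm(x, ord=1)
        # the norm of a complex128 vector is a real float64 scalar
        self.assertEqual(res.dtype, torch.float64)
        self.assertEqual(res, x.abs().sum())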

    # Test that linalg.vector_norm gives the same result as numpy when inputs
    # contain extreme values (inf, -inf, nan)
    def test_vector_norm_extreme_values(self, device):
        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
        vectors = []
        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
            vectors.append(list(pair))
        for vector in vectors:
            x = torch.tensor(vector, device=device)
            x_n = x.cpu().numpy()
            for ord in vector_ords:
                msg = f'ord={ord}, vector={vector}'
                result = torch.linalg.vector_norm(x, ord=ord)
                result_n = np.linalg.norm(x_n, ord=ord)
                self.assertEqual(result, result_n, msg=msg)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_vector_norm_reduce_over_1D_vector(self, device, dtype):
        input_sizes_and_dims = [
            ((6, 1), -1),
            ((3, 1, 2, 1), (1, 3)),
            ((1,), None),
        ]
        orders = [float('inf'), -float('inf'), 0, 1, -1, 2, -2]
        keepdims = [True, False]

        for input_size_and_dim, ord, keepdim in product(input_sizes_and_dims, orders, keepdims):
            input_size = input_size_and_dim[0]
            dim = input_size_and_dim[1]
            if type(dim) is tuple and ord == 0:
                # skip because np.linalg.norm raises 'ValueError: Invalid norm order for matrices.'
                continue
            input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
            result = torch.linalg.vector_norm(input, ord, dim, keepdim)
            result_numpy = np.linalg.norm(input.cpu().numpy(), ord, dim, keepdim)

            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
            self.assertEqual(result, result_numpy, msg=msg)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    @precisionOverride({torch.float32: 2e-5})
    def test_matrix_norm(self, device, dtype):
        # Test only inputs for which torch.linalg.matrix_norm diverges from torch.linalg.norm
        A = make_tensor((2, 2, 2), dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must have at least 2 dimensions.*'):
            torch.linalg.matrix_norm(make_tensor((2,), dtype=dtype, device=device))
        with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must be a 2-tuple.*'):
            torch.linalg.matrix_norm(A, dim=(0,))
        with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
            torch.linalg.matrix_norm(A, ord=0)
        with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
            torch.linalg.matrix_norm(A, ord=3.0)

        # Test dim=None behavior
        ref = torch.linalg.norm(A, dim=(-2, -1))
        res = torch.linalg.matrix_norm(A)
        self.assertEqual(ref, res)

    # Test that linalg.norm gives the same result as numpy when inputs
    # contain extreme values (inf, -inf, nan)
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(IS_MACOS, "Skipped on MacOS!")
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    def test_norm_extreme_values(self, device):
        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
        # matrix_ords 'nuc', 2, -2 are skipped currently
        # See issue https://github.com/pytorch/pytorch/issues/71911
        matrix_ords = ['fro', 1, inf, -1, -inf]
        vectors = []
        matrices = []
        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
            vectors.append(list(pair))
            matrices.append([[pair[0], pair[1]]])
            matrices.append([[pair[0]], [pair[1]]])
        for vector in vectors:
            x = torch.tensor(vector).to(device)
            x_n = x.cpu().numpy()
            for ord in vector_ords:
                msg = f'ord={ord}, vector={vector}'
                result = torch.linalg.norm(x, ord=ord)
                result_n = np.linalg.norm(x_n, ord=ord)
                self.assertEqual(result, result_n, msg=msg)

        # TODO: Remove this function once the broken cases are fixed
        def is_broken_matrix_norm_case(ord, x):
            if self.device_type == 'cuda':
                if x.size() == torch.Size([1, 2]):
                    if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1:
                        # These cases are broken because of an issue with svd
                        # https://github.com/pytorch/pytorch/issues/43567
                        return True
                if ord in ['nuc', 2, -2]:
                    # These cases are broken because of another issue with svd
                    # https://github.com/pytorch/pytorch/issues/52633
                    return True
            return False

        for matrix in matrices:
            x = torch.tensor(matrix).to(device)
            x_n = x.cpu().numpy()
            for ord in matrix_ords:
                msg = f'ord={ord}, matrix={matrix}'
                if is_broken_matrix_norm_case(ord, x):
                    continue
                else:
                    result_n = np.linalg.norm(x_n, ord=ord)
                    result = torch.linalg.norm(x, ord=ord)
                    self.assertEqual(result, result_n, msg=msg)

    # Test degenerate shape results match numpy for linalg.norm vector norms
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_norm_vector_degenerate_shapes(self, device, dtype):
        def run_test_case(input, ord, dim, keepdim):
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
            if (input.numel() == 0 and
                    (ord < 0. or ord == inf) and
                    (dim is None or input.shape[dim] == 0)):
                with self.assertRaises(RuntimeError):
                    torch.linalg.norm(input, ord, dim, keepdim)
            else:
                input_numpy = input.cpu().numpy()
                result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
                result = torch.linalg.norm(input, ord, dim, keepdim)
                self.assertEqual(result, result_numpy, msg=msg)

        ord_vector = [0, 0.5, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
        S = 10
        test_cases = [
            # input size, dim
            ((0, ), None),
            ((0, S), 0),
            ((0, S), 1),
            ((S, 0), 0),
            ((S, 0), 1),
        ]
        for keepdim in [True, False]:
            for input_size, dim in test_cases:
                input = torch.randn(*input_size, dtype=dtype, device=device)
                for ord in ord_vector:
                    run_test_case(input, ord, dim, keepdim)

    # Test degenerate shape results match numpy for linalg.norm matrix norms
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_norm_matrix_degenerate_shapes(self, device, dtype):
        def run_test_case(input, ord, dim, keepdim, should_error):
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
            input_numpy = input.cpu().numpy()
            ops = [torch.linalg.norm]

            if ord is not None and dim is not None:
                ops.append(torch.linalg.matrix_norm)

            if should_error:
                with self.assertRaises(ValueError):
                    np.linalg.norm(input_numpy, ord, dim, keepdim)
                for op in ops:
                    with self.assertRaises(IndexError):
                        op(input, ord, dim, keepdim)
            else:
                result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
                for op in ops:
                    result = op(input, ord, dim, keepdim)
                    self.assertEqual(result, result_numpy, msg=msg)

        ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None]
        S = 10
        test_cases = [
            # input size, p settings that cause error, dim
            ((0, 0), [1, 2, inf, -1, -2, -inf], None),
            ((0, S), [2, inf, -2, -inf], None),
            ((S, 0), [1, 2, -1, -2], None),
            ((S, S, 0), [], (0, 1)),
            ((1, S, 0), [], (0, 1)),
            ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)),
            ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)),
        ]

        for keepdim in [True, False]:
            for input_size, error_ords, dim in test_cases:
                input = torch.randn(*input_size, dtype=dtype, device=device)
                for ord in ord_matrix:
                    run_test_case(input, ord, dim, keepdim, ord in error_ords)

    def test_norm_fastpaths(self, device):
        x = torch.randn(3, 5, device=device)

        # slow path
        result = torch.linalg.norm(x, 4.5, 1)
        expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5)
        self.assertEqual(result, expected)

        # fast 0-norm
        result = torch.linalg.norm(x, 0, 1)
        expected = (x != 0).type_as(x).sum(1)
        self.assertEqual(result, expected)

        # fast 1-norm
        result = torch.linalg.norm(x, 1, 1)
        expected = x.abs().sum(1)
        self.assertEqual(result, expected)

        # fast 2-norm
        result = torch.linalg.norm(x, 2, 1)
        expected = torch.sqrt(x.pow(2).sum(1))
        self.assertEqual(result, expected)

        # fast 3-norm
        result = torch.linalg.norm(x, 3, 1)
        expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0)
        self.assertEqual(result, expected)

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    # NumPy computes only in float64 and complex128 precisions
    # for float32 or complex64 results might be very different from float64 or complex128
    @dtypes(torch.float64, torch.complex128)
    def test_eig_numpy(self, device, dtype):
        def run_test(shape, *, symmetric=False):
            from torch.testing._internal.common_utils import random_symmetric_matrix

            if not dtype.is_complex and symmetric:
                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
                # unlike NumPy the result is not cast to float32 or float64 dtype in this case
                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
            else:
                a = make_tensor(shape, dtype=dtype, device=device)

            actual = torch.linalg.eig(a)

            # compare with NumPy
            # the eigenvalues are not necessarily ordered
            # so order of NumPy and PyTorch can be different
            expected = np.linalg.eig(a.cpu().numpy())

            # sort NumPy output
            ind = np.argsort(expected[0], axis=-1)[::-1]
            expected = (np.take_along_axis(expected[0], ind, axis=-1), np.take_along_axis(expected[1], ind[:, None], axis=-1))

            # sort PyTorch output
            # torch.argsort doesn't work with complex inputs, NumPy sorting on CPU is used instead
            # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble
            # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble'
            ind = np.argsort(actual[0].cpu().numpy(), axis=-1)[::-1]
            actual_np = [x.cpu().numpy() for x in actual]
            sorted_actual = (
                np.take_along_axis(actual_np[0], ind, axis=-1),
                np.take_along_axis(actual_np[1], ind[:, None], axis=-1))

            self.assertEqual(expected[0], sorted_actual[0], exact_dtype=False)
            self.assertEqual(abs(expected[1]), abs(sorted_actual[1]), exact_dtype=False)

        shapes = [(0, 0),  # Empty matrix
                  (5, 5),  # Single matrix
                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
                  (2, 5, 5),  # 3-dim tensors
                  (2, 1, 5, 5)]  # 4-dim tensors
        for shape in shapes:
            run_test(shape)
            run_test(shape, symmetric=True)

    @onlyCUDA
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    def test_eig_compare_backends(self, device, dtype):
        def run_test(shape, *, symmetric=False):
            from torch.testing._internal.common_utils import random_symmetric_matrix

            if not dtype.is_complex and symmetric:
                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
            else:
                a = make_tensor(shape, dtype=dtype, device=device)

            actual = torch.linalg.eig(a)

            complementary_device = 'cpu'

            # compare with CPU
            expected = torch.linalg.eig(a.to(complementary_device))
            self.assertEqual(expected[0], actual[0])
            self.assertEqual(expected[1], actual[1])

        shapes = [(0, 0),  # Empty matrix
                  (5, 5),  # Single matrix
                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
                  (2, 5, 5),  # 3-dim tensors
                  (2, 1, 5, 5)]  # 4-dim tensors
        for shape in shapes:
            run_test(shape)
            run_test(shape, symmetric=True)

    @slowTest
    @onlyCUDA
    @skipCUDAIfNoMagma
    @dtypes(torch.float32)
    def test_eig_check_magma(self, device, dtype):
        # For CUDA inputs only matrices of size larger than 2048x2048 actually call MAGMA library
        shape = (2049, 2049)
        a = make_tensor(shape, dtype=dtype, device=device)
        w, v = torch.linalg.eig(a)
        # check correctness using eigendecomposition identity
        self.assertEqual(a.to(v.dtype) @ v, w * v, atol=1e-3, rtol=1e-3)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_eig_errors_and_warnings(self, device, dtype):
        # eig requires the input to be at least 2 dimensional tensor
        a = make_tensor(2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
            torch.linalg.eig(a)

        # eig requires a square matrix
        a = make_tensor((2, 3), dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.linalg.eig(a)

        # if out tensor with floating dtype is passed for complex output an error is thrown
        if not dtype.is_complex:
            # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i
            a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
            out0 = torch.empty(0, device=device, dtype=dtype)
            out1 = torch.empty(0, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
                torch.linalg.eig(a, out=(out0, out1))

            out0 = torch.empty(0, device=device, dtype=torch.complex128)
            with self.assertRaisesRegex(RuntimeError, "Expected eigenvectors to be safely castable"):
                torch.linalg.eig(a, out=(out0, out1))

        # dtypes should be safely castable
        a = make_tensor((3, 3), dtype=dtype, device=device)
        out0 = torch.empty(0, dtype=torch.int, device=device)
        out1 = torch.empty(0, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
            torch.linalg.eig(a, out=(out0, out1))

        out0 = torch.empty(0, dtype=torch.complex128, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got eigenvectors with dtype Int"):
            torch.linalg.eig(a, out=(out0, out1))

        # if non-empty out tensor with wrong shape is passed a warning is given
        a = make_tensor((3, 3), dtype=dtype, device=device)
        out0 = torch.empty(1, device=device, dtype=torch.complex128)
        out1 = torch.empty(1, device=device, dtype=torch.complex128)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.eig(a, out=(out0, out1))
            # Check warning occurs
            self.assertEqual(len(w), 2)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
            out_v = torch.empty(0, device=device, dtype=torch.complex128)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eig(a, out=(out_w, out_v))
            out_w = torch.empty(0, device=device, dtype=torch.complex128)
            out_v = torch.empty(0, device=wrong_device, dtype=torch.complex128)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.eig(a, out=(out_w, out_v))

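    # Illustrative sketch (not exercised by the suite): the MAGMA check above
    # validates results through the eigendecomposition identity
    # A @ V == V @ diag(w); multiplying V by w broadcasts over columns and is
    # equivalent to V @ torch.diag(w). Helper name, size and tolerances are
    # assumptions made for this example.
    def _sketch_eig_identity(self, device='cpu'):
        A = make_tensor((4, 4), dtype=torch.float64, device=device)
        w, V = torch.linalg.eig(A)
        # eigenvalues/eigenvectors are complex even for real input
        self.assertEqual(A.to(V.dtype) @ V, V * w, atol=1e-6, rtol=1e-6)
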
Worker @skipCPUIfNoLapack 2028*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2029*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2030*da0073e9SAndroid Build Coastguard Worker def test_eig_with_nan(self, device, dtype): 2031*da0073e9SAndroid Build Coastguard Worker for val in [np.inf, np.nan]: 2032*da0073e9SAndroid Build Coastguard Worker for batch_dim in [(), (10,)]: 2033*da0073e9SAndroid Build Coastguard Worker a = make_tensor((*batch_dim, 5, 5), device=device, dtype=dtype) 2034*da0073e9SAndroid Build Coastguard Worker a[..., -1, -1] = val 2035*da0073e9SAndroid Build Coastguard Worker 2036*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "torch.linalg.eig: input tensor should not"): 2037*da0073e9SAndroid Build Coastguard Worker torch.linalg.eig(a) 2038*da0073e9SAndroid Build Coastguard Worker 2039*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2040*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2041*da0073e9SAndroid Build Coastguard Worker # NumPy computes only in float64 and complex128 precisions 2042*da0073e9SAndroid Build Coastguard Worker # for float32 or complex64 results might be very different from float64 or complex128 2043*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float64, torch.complex128) 2044*da0073e9SAndroid Build Coastguard Worker def test_eigvals_numpy(self, device, dtype): 2045*da0073e9SAndroid Build Coastguard Worker def run_test(shape, *, symmetric=False): 2046*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.common_utils import random_symmetric_matrix 2047*da0073e9SAndroid Build Coastguard Worker 2048*da0073e9SAndroid Build Coastguard Worker if not dtype.is_complex and symmetric: 2049*da0073e9SAndroid Build Coastguard Worker # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero 2050*da0073e9SAndroid Build Coastguard Worker # unlike NumPy the result is not cast to float32 or float64 dtype in this case 2051*da0073e9SAndroid Build Coastguard Worker a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device) 2052*da0073e9SAndroid Build Coastguard Worker else: 2053*da0073e9SAndroid Build Coastguard Worker a = make_tensor(shape, dtype=dtype, device=device) 2054*da0073e9SAndroid Build Coastguard Worker 2055*da0073e9SAndroid Build Coastguard Worker actual = torch.linalg.eigvals(a) 2056*da0073e9SAndroid Build Coastguard Worker 2057*da0073e9SAndroid Build Coastguard Worker # compare with NumPy 2058*da0073e9SAndroid Build Coastguard Worker # the eigenvalues are not necessarily ordered 2059*da0073e9SAndroid Build Coastguard Worker # so order of NumPy and PyTorch can be different 2060*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.eigvals(a.cpu().numpy()) 2061*da0073e9SAndroid Build Coastguard Worker 2062*da0073e9SAndroid Build Coastguard Worker # sort NumPy output 2063*da0073e9SAndroid Build Coastguard Worker ind = np.argsort(expected, axis=-1)[::-1] 2064*da0073e9SAndroid Build Coastguard Worker expected = np.take_along_axis(expected, ind, axis=-1) 2065*da0073e9SAndroid Build Coastguard Worker 2066*da0073e9SAndroid Build Coastguard Worker # sort PyTorch output 2067*da0073e9SAndroid Build Coastguard Worker # torch.argsort doesn't work with complex inputs, NumPy sorting on CPU is used instead 2068*da0073e9SAndroid Build Coastguard Worker # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble 2069*da0073e9SAndroid Build Coastguard Worker # RuntimeError: 
"sorting_kernel_method_name" not implemented for 'ComplexDouble' 2070*da0073e9SAndroid Build Coastguard Worker ind = np.argsort(actual.cpu().numpy(), axis=-1)[::-1] 2071*da0073e9SAndroid Build Coastguard Worker actual_np = actual.cpu().numpy() 2072*da0073e9SAndroid Build Coastguard Worker sorted_actual = np.take_along_axis(actual_np, ind, axis=-1) 2073*da0073e9SAndroid Build Coastguard Worker 2074*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, sorted_actual, exact_dtype=False) 2075*da0073e9SAndroid Build Coastguard Worker 2076*da0073e9SAndroid Build Coastguard Worker shapes = [(0, 0), # Empty matrix 2077*da0073e9SAndroid Build Coastguard Worker (5, 5), # Single matrix 2078*da0073e9SAndroid Build Coastguard Worker (0, 0, 0), (0, 5, 5), # Zero batch dimension tensors 2079*da0073e9SAndroid Build Coastguard Worker (2, 5, 5), # 3-dim tensors 2080*da0073e9SAndroid Build Coastguard Worker (2, 1, 5, 5)] # 4-dim tensors 2081*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 2082*da0073e9SAndroid Build Coastguard Worker run_test(shape) 2083*da0073e9SAndroid Build Coastguard Worker run_test(shape, symmetric=True) 2084*da0073e9SAndroid Build Coastguard Worker 2085*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2086*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2087*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2088*da0073e9SAndroid Build Coastguard Worker def test_eigvals_compare_backends(self, device, dtype): 2089*da0073e9SAndroid Build Coastguard Worker def run_test(shape, *, symmetric=False): 2090*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.common_utils import random_symmetric_matrix 2091*da0073e9SAndroid Build Coastguard Worker 2092*da0073e9SAndroid Build Coastguard Worker if not dtype.is_complex and symmetric: 2093*da0073e9SAndroid Build Coastguard Worker # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero 2094*da0073e9SAndroid Build Coastguard Worker a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device) 2095*da0073e9SAndroid Build Coastguard Worker else: 2096*da0073e9SAndroid Build Coastguard Worker a = make_tensor(shape, dtype=dtype, device=device) 2097*da0073e9SAndroid Build Coastguard Worker 2098*da0073e9SAndroid Build Coastguard Worker actual = torch.linalg.eigvals(a) 2099*da0073e9SAndroid Build Coastguard Worker 2100*da0073e9SAndroid Build Coastguard Worker complementary_device = 'cpu' 2101*da0073e9SAndroid Build Coastguard Worker 2102*da0073e9SAndroid Build Coastguard Worker # compare with CPU 2103*da0073e9SAndroid Build Coastguard Worker expected = torch.linalg.eigvals(a.to(complementary_device)) 2104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 2105*da0073e9SAndroid Build Coastguard Worker 2106*da0073e9SAndroid Build Coastguard Worker # check out= variant 2107*da0073e9SAndroid Build Coastguard Worker complex_dtype = dtype 2108*da0073e9SAndroid Build Coastguard Worker if not dtype.is_complex: 2109*da0073e9SAndroid Build Coastguard Worker complex_dtype = torch.complex128 if dtype == torch.float64 else torch.complex64 2110*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0, dtype=complex_dtype, device=device) 2111*da0073e9SAndroid Build Coastguard Worker ans = torch.linalg.eigvals(a, out=out) 2112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ans, out) 2113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.to(complex_dtype), out) 
2114*da0073e9SAndroid Build Coastguard Worker 2115*da0073e9SAndroid Build Coastguard Worker # check non-contiguous out 2116*da0073e9SAndroid Build Coastguard Worker if a.numel() > 0: 2117*da0073e9SAndroid Build Coastguard Worker out = torch.empty(2 * shape[0], *shape[1:-1], dtype=complex_dtype, device=device)[::2] 2118*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out.is_contiguous()) 2119*da0073e9SAndroid Build Coastguard Worker ans = torch.linalg.eigvals(a, out=out) 2120*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ans, out) 2121*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.to(complex_dtype), out) 2122*da0073e9SAndroid Build Coastguard Worker 2123*da0073e9SAndroid Build Coastguard Worker shapes = [(0, 0), # Empty matrix 2124*da0073e9SAndroid Build Coastguard Worker (5, 5), # Single matrix 2125*da0073e9SAndroid Build Coastguard Worker (0, 0, 0), (0, 5, 5), # Zero batch dimension tensors 2126*da0073e9SAndroid Build Coastguard Worker (2, 5, 5), # 3-dim tensors 2127*da0073e9SAndroid Build Coastguard Worker (2, 1, 5, 5)] # 4-dim tensors 2128*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 2129*da0073e9SAndroid Build Coastguard Worker run_test(shape) 2130*da0073e9SAndroid Build Coastguard Worker run_test(shape, symmetric=True) 2131*da0073e9SAndroid Build Coastguard Worker 2132*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2133*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2134*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2135*da0073e9SAndroid Build Coastguard Worker def test_eigvals_errors_and_warnings(self, device, dtype): 2136*da0073e9SAndroid Build Coastguard Worker # eig requires the input to be at least 2 dimensional tensor 2137*da0073e9SAndroid Build Coastguard Worker a = make_tensor(2, dtype=dtype, device=device) 2138*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"): 2139*da0073e9SAndroid Build Coastguard Worker torch.linalg.eigvals(a) 2140*da0073e9SAndroid Build Coastguard Worker 2141*da0073e9SAndroid Build Coastguard Worker # eig requires a square matrix 2142*da0073e9SAndroid Build Coastguard Worker a = make_tensor((2, 3), dtype=dtype, device=device) 2143*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): 2144*da0073e9SAndroid Build Coastguard Worker torch.linalg.eigvals(a) 2145*da0073e9SAndroid Build Coastguard Worker 2146*da0073e9SAndroid Build Coastguard Worker # if out tensor with floating dtype is passed for complex output an error is thrown 2147*da0073e9SAndroid Build Coastguard Worker if not dtype.is_complex: 2148*da0073e9SAndroid Build Coastguard Worker # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i 2149*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device) 2150*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0, device=device, dtype=dtype) 2151*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"): 2152*da0073e9SAndroid Build Coastguard Worker torch.linalg.eigvals(a, out=out) 2153*da0073e9SAndroid Build Coastguard Worker 2154*da0073e9SAndroid Build Coastguard Worker # dtypes should be safely castable 2155*da0073e9SAndroid Build Coastguard Worker a = make_tensor((3, 3), dtype=dtype, device=device) 2156*da0073e9SAndroid Build 
Coastguard Worker out = torch.empty(0, dtype=torch.int, device=device) 2157*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"): 2158*da0073e9SAndroid Build Coastguard Worker torch.linalg.eigvals(a, out=out) 2159*da0073e9SAndroid Build Coastguard Worker 2160*da0073e9SAndroid Build Coastguard Worker # if non-empty out tensor with wrong shape is passed a warning is given 2161*da0073e9SAndroid Build Coastguard Worker out = torch.empty(1, device=device, dtype=torch.complex128) 2162*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 2163*da0073e9SAndroid Build Coastguard Worker # Trigger warning 2164*da0073e9SAndroid Build Coastguard Worker torch.linalg.eigvals(a, out=out) 2165*da0073e9SAndroid Build Coastguard Worker # Check warning occurs 2166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 2167*da0073e9SAndroid Build Coastguard Worker self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 2168*da0073e9SAndroid Build Coastguard Worker 2169*da0073e9SAndroid Build Coastguard Worker # device should match 2170*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 2171*da0073e9SAndroid Build Coastguard Worker wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 2172*da0073e9SAndroid Build Coastguard Worker out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128) 2173*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 2174*da0073e9SAndroid Build Coastguard Worker torch.linalg.eigvals(a, out=out_w) 2175*da0073e9SAndroid Build Coastguard Worker 2176*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2177*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2178*da0073e9SAndroid Build Coastguard Worker def test_norm_old(self, device): 2179*da0073e9SAndroid Build Coastguard Worker def gen_error_message(input_size, p, keepdim, dim=None): 2180*da0073e9SAndroid Build Coastguard Worker return f"norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}" 2181*da0073e9SAndroid Build Coastguard Worker 2182*da0073e9SAndroid Build Coastguard Worker # 'nuc' norm uses SVD, and thus its precision is much lower than that of other norms. 2183*da0073e9SAndroid Build Coastguard Worker # test_svd takes @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4}), 2184*da0073e9SAndroid Build Coastguard Worker # and here we are doing the same thing for nuc norm.
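        # Helper context manager: while comparing 'nuc' norm results (which go through SVD and are therefore less precise), temporarily install the looser precision overrides on the test instance and restore the previous state on exit.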
2185*da0073e9SAndroid Build Coastguard Worker class PrecisionContext: 2186*da0073e9SAndroid Build Coastguard Worker def __init__(self, test, norm): 2187*da0073e9SAndroid Build Coastguard Worker self.norm = norm 2188*da0073e9SAndroid Build Coastguard Worker self.saved_overrides = getattr(test, 'precision_overrides', None) 2189*da0073e9SAndroid Build Coastguard Worker self.target_test = test 2190*da0073e9SAndroid Build Coastguard Worker 2191*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 2192*da0073e9SAndroid Build Coastguard Worker if 'nuc' != self.norm: 2193*da0073e9SAndroid Build Coastguard Worker return None 2194*da0073e9SAndroid Build Coastguard Worker self.target_test.precision_overrides = {torch.float: 1e-4, torch.cfloat: 2e-4} 2195*da0073e9SAndroid Build Coastguard Worker return self.target_test.precision_overrides 2196*da0073e9SAndroid Build Coastguard Worker 2197*da0073e9SAndroid Build Coastguard Worker def __exit__(self, exc_type, exc_value, tb) -> bool: 2198*da0073e9SAndroid Build Coastguard Worker if 'nuc' != self.norm: 2199*da0073e9SAndroid Build Coastguard Worker return False 2200*da0073e9SAndroid Build Coastguard Worker if self.saved_overrides is None: 2201*da0073e9SAndroid Build Coastguard Worker delattr(self.target_test, 'precision_overrides') 2202*da0073e9SAndroid Build Coastguard Worker else: 2203*da0073e9SAndroid Build Coastguard Worker self.target_test.precision_overrides = self.saved_overrides 2204*da0073e9SAndroid Build Coastguard Worker return False  # do not suppress exceptions (e.g. assertion failures) raised inside the with-block 2205*da0073e9SAndroid Build Coastguard Worker 2206*da0073e9SAndroid Build Coastguard Worker for keepdim in [False, True]: 2207*da0073e9SAndroid Build Coastguard Worker # full reduction 2208*da0073e9SAndroid Build Coastguard Worker x = torch.randn(25, device=device) 2209*da0073e9SAndroid Build Coastguard Worker xn = x.cpu().numpy() 2210*da0073e9SAndroid Build Coastguard Worker for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]: 2211*da0073e9SAndroid Build Coastguard Worker res = x.norm(p, keepdim=keepdim).cpu() 2212*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.norm(xn, p, keepdims=keepdim) 2213*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim)) 2214*da0073e9SAndroid Build Coastguard Worker 2215*da0073e9SAndroid Build Coastguard Worker # one dimension 2216*da0073e9SAndroid Build Coastguard Worker x = torch.randn(25, 25, device=device) 2217*da0073e9SAndroid Build Coastguard Worker xn = x.cpu().numpy() 2218*da0073e9SAndroid Build Coastguard Worker for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]: 2219*da0073e9SAndroid Build Coastguard Worker dim = 1 2220*da0073e9SAndroid Build Coastguard Worker res = x.norm(p, dim, keepdim=keepdim).cpu() 2221*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.norm(xn, p, dim, keepdims=keepdim) 2222*da0073e9SAndroid Build Coastguard Worker msg = gen_error_message(x.size(), p, keepdim, dim) 2223*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.shape, expected.shape, msg=msg) 2224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, msg=msg) 2225*da0073e9SAndroid Build Coastguard Worker 2226*da0073e9SAndroid Build Coastguard Worker # matrix norm 2227*da0073e9SAndroid Build Coastguard Worker for p in ['fro', 'nuc']: 2228*da0073e9SAndroid Build Coastguard Worker res = x.norm(p, keepdim=keepdim).cpu() 2229*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.norm(xn, p, keepdims=keepdim) 2230*da0073e9SAndroid Build Coastguard Worker msg =
gen_error_message(x.size(), p, keepdim) 2231*da0073e9SAndroid Build Coastguard Worker with PrecisionContext(self, p): 2232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.shape, expected.shape, msg=msg) 2233*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, msg=msg) 2234*da0073e9SAndroid Build Coastguard Worker 2235*da0073e9SAndroid Build Coastguard Worker # zero dimensions 2236*da0073e9SAndroid Build Coastguard Worker x = torch.randn((), device=device) 2237*da0073e9SAndroid Build Coastguard Worker xn = x.cpu().numpy() 2238*da0073e9SAndroid Build Coastguard Worker res = x.norm(keepdim=keepdim).cpu() 2239*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.norm(xn, keepdims=keepdim) 2240*da0073e9SAndroid Build Coastguard Worker msg = gen_error_message(x.size(), None, keepdim) 2241*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.shape, expected.shape, msg=msg) 2242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, msg=msg) 2243*da0073e9SAndroid Build Coastguard Worker 2244*da0073e9SAndroid Build Coastguard Worker # larger tensor sanity check 2245*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2246*da0073e9SAndroid Build Coastguard Worker 2 * torch.norm(torch.ones(10000), keepdim=keepdim), 2247*da0073e9SAndroid Build Coastguard Worker torch.norm(torch.ones(40000), keepdim=keepdim)) 2248*da0073e9SAndroid Build Coastguard Worker 2249*da0073e9SAndroid Build Coastguard Worker # matrix norm with non-square >2-D tensors, all combinations of reduction dims 2250*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 6, 7, 8, device=device) 2251*da0073e9SAndroid Build Coastguard Worker xn = x.cpu().numpy() 2252*da0073e9SAndroid Build Coastguard Worker for p in ['fro', 'nuc']: 2253*da0073e9SAndroid Build Coastguard Worker for dim in itertools.product(*[list(range(4))] * 2): 2254*da0073e9SAndroid Build Coastguard Worker if dim[0] == dim[1]: 2255*da0073e9SAndroid Build Coastguard Worker continue 2256*da0073e9SAndroid Build Coastguard Worker res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu() 2257*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim) 2258*da0073e9SAndroid Build Coastguard Worker msg = gen_error_message(x.size(), p, keepdim, dim) 2259*da0073e9SAndroid Build Coastguard Worker with PrecisionContext(self, p): 2260*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.shape, expected.shape, msg=msg) 2261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, msg=msg) 2262*da0073e9SAndroid Build Coastguard Worker 2263*da0073e9SAndroid Build Coastguard Worker # Test that torch.norm with p=+/-inf propagates NaN 2264*da0073e9SAndroid Build Coastguard Worker def test_norm_old_nan_propagation(self, device): 2265*da0073e9SAndroid Build Coastguard Worker ords = [inf, -inf] 2266*da0073e9SAndroid Build Coastguard Worker for pair in itertools.product([0.0, nan, 1.0], repeat=2): 2267*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(list(pair), device=device) 2268*da0073e9SAndroid Build Coastguard Worker for ord in ords: 2269*da0073e9SAndroid Build Coastguard Worker result = torch.norm(x, p=ord) 2270*da0073e9SAndroid Build Coastguard Worker result_check = torch.linalg.norm(x, ord=ord) 2271*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, result_check) 2272*da0073e9SAndroid Build Coastguard Worker 2273*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2274*da0073e9SAndroid Build Coastguard Worker 
@skipCPUIfNoLapack 2275*da0073e9SAndroid Build Coastguard Worker def test_norm_complex_old(self, device): 2276*da0073e9SAndroid Build Coastguard Worker def gen_error_message(input_size, p, keepdim, dim=None): 2277*da0073e9SAndroid Build Coastguard Worker return f"complex norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}" 2278*da0073e9SAndroid Build Coastguard Worker 2279*da0073e9SAndroid Build Coastguard Worker for keepdim in [False, True]: 2280*da0073e9SAndroid Build Coastguard Worker # vector norm 2281*da0073e9SAndroid Build Coastguard Worker x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device) 2282*da0073e9SAndroid Build Coastguard Worker xn = x.cpu().numpy() 2283*da0073e9SAndroid Build Coastguard Worker for p in [0, 1, 2, 3, inf, -1, -2, -3, -inf]: 2284*da0073e9SAndroid Build Coastguard Worker res = x.norm(p, keepdim=keepdim).cpu() 2285*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.norm(xn, p, keepdims=keepdim) 2286*da0073e9SAndroid Build Coastguard Worker msg = gen_error_message(x.size(), p, keepdim) 2287*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.shape, expected.shape, msg=msg) 2288*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, msg=msg) 2289*da0073e9SAndroid Build Coastguard Worker 2290*da0073e9SAndroid Build Coastguard Worker # matrix norm 2291*da0073e9SAndroid Build Coastguard Worker x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device) 2292*da0073e9SAndroid Build Coastguard Worker xn = x.cpu().numpy() 2293*da0073e9SAndroid Build Coastguard Worker for p in ['nuc', 'fro']: 2294*da0073e9SAndroid Build Coastguard Worker res = x.norm(p, keepdim=keepdim).cpu() 2295*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.norm(xn, p, keepdims=keepdim) 2296*da0073e9SAndroid Build Coastguard Worker msg = gen_error_message(x.size(), p, keepdim) 2297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.shape, expected.shape, msg=msg) 2298*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, msg=msg, rtol=4e-6, atol=6e-4) 2299*da0073e9SAndroid Build Coastguard Worker 2300*da0073e9SAndroid Build Coastguard Worker # Ensure torch.norm with p='fro' and p=2 give the same results for mutually supported input combinations 2301*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 2302*da0073e9SAndroid Build Coastguard Worker def test_norm_fro_2_equivalence_old(self, device, dtype): 2303*da0073e9SAndroid Build Coastguard Worker input_sizes = [ 2304*da0073e9SAndroid Build Coastguard Worker (0,), 2305*da0073e9SAndroid Build Coastguard Worker (10,), 2306*da0073e9SAndroid Build Coastguard Worker (0, 0), 2307*da0073e9SAndroid Build Coastguard Worker (4, 30), 2308*da0073e9SAndroid Build Coastguard Worker (0, 45), 2309*da0073e9SAndroid Build Coastguard Worker (100, 0), 2310*da0073e9SAndroid Build Coastguard Worker (45, 10, 23), 2311*da0073e9SAndroid Build Coastguard Worker (0, 23, 59), 2312*da0073e9SAndroid Build Coastguard Worker (23, 0, 37), 2313*da0073e9SAndroid Build Coastguard Worker (34, 58, 0), 2314*da0073e9SAndroid Build Coastguard Worker (0, 0, 348), 2315*da0073e9SAndroid Build Coastguard Worker (0, 3434, 0), 2316*da0073e9SAndroid Build Coastguard Worker (0, 0, 0), 2317*da0073e9SAndroid Build Coastguard Worker (5, 3, 8, 1, 3, 5)] 2318*da0073e9SAndroid Build Coastguard Worker 2319*da0073e9SAndroid Build Coastguard Worker for input_size in input_sizes: 2320*da0073e9SAndroid Build Coastguard Worker a = make_tensor(input_size, 
dtype=dtype, device=device, low=-9, high=9) 2321*da0073e9SAndroid Build Coastguard Worker 2322*da0073e9SAndroid Build Coastguard Worker # Try full reduction 2323*da0073e9SAndroid Build Coastguard Worker dim_settings = [None] 2324*da0073e9SAndroid Build Coastguard Worker 2325*da0073e9SAndroid Build Coastguard Worker # Try all possible 1-D reductions 2326*da0073e9SAndroid Build Coastguard Worker dim_settings += list(range(-a.dim(), a.dim())) 2327*da0073e9SAndroid Build Coastguard Worker 2328*da0073e9SAndroid Build Coastguard Worker def wrap_dim(dim, ndims): 2329*da0073e9SAndroid Build Coastguard Worker assert (dim < ndims) and (dim >= -ndims) 2330*da0073e9SAndroid Build Coastguard Worker if dim >= 0: 2331*da0073e9SAndroid Build Coastguard Worker return dim 2332*da0073e9SAndroid Build Coastguard Worker else: 2333*da0073e9SAndroid Build Coastguard Worker return dim + ndims 2334*da0073e9SAndroid Build Coastguard Worker 2335*da0073e9SAndroid Build Coastguard Worker # Try all possible 2-D reductions 2336*da0073e9SAndroid Build Coastguard Worker dim_settings += [ 2337*da0073e9SAndroid Build Coastguard Worker (d0, d1) for d0, d1 in itertools.combinations(range(-a.dim(), a.dim()), 2) 2338*da0073e9SAndroid Build Coastguard Worker if wrap_dim(d0, a.dim()) != wrap_dim(d1, a.dim())] 2339*da0073e9SAndroid Build Coastguard Worker 2340*da0073e9SAndroid Build Coastguard Worker for dim in dim_settings: 2341*da0073e9SAndroid Build Coastguard Worker for keepdim in [True, False]: 2342*da0073e9SAndroid Build Coastguard Worker a_norm_2 = torch.norm(a, p=2, dim=dim, keepdim=keepdim) 2343*da0073e9SAndroid Build Coastguard Worker a_norm_fro = torch.norm(a, p='fro', dim=dim, keepdim=keepdim) 2344*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_norm_fro, a_norm_2) 2345*da0073e9SAndroid Build Coastguard Worker 2346*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 2347*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2348*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2349*da0073e9SAndroid Build Coastguard Worker def test_nuclear_norm_axes_small_brute_force_old(self, device): 2350*da0073e9SAndroid Build Coastguard Worker def check_single_nuclear_norm(x, axes): 2351*da0073e9SAndroid Build Coastguard Worker if self.device_type != 'cpu' and randrange(100) < 95: 2352*da0073e9SAndroid Build Coastguard Worker return # too many cpu <==> device copies 2353*da0073e9SAndroid Build Coastguard Worker 2354*da0073e9SAndroid Build Coastguard Worker a = np.array(x.cpu(), copy=False) 2355*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.norm(a, "nuc", axis=axes) 2356*da0073e9SAndroid Build Coastguard Worker 2357*da0073e9SAndroid Build Coastguard Worker ans = torch.norm(x, "nuc", dim=axes) 2358*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ans.is_contiguous()) 2359*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ans.shape, expected.shape) 2360*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True) 2361*da0073e9SAndroid Build Coastguard Worker 2362*da0073e9SAndroid Build Coastguard Worker out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device) 2363*da0073e9SAndroid Build Coastguard Worker ans = torch.norm(x, "nuc", dim=axes, out=out) 2364*da0073e9SAndroid Build Coastguard Worker self.assertIs(ans, out) 2365*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ans.is_contiguous()) 2366*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(ans.shape, expected.shape) 2367*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True) 2368*da0073e9SAndroid Build Coastguard Worker 2369*da0073e9SAndroid Build Coastguard Worker for n in range(1, 3): 2370*da0073e9SAndroid Build Coastguard Worker for m in range(1, 3): 2371*da0073e9SAndroid Build Coastguard Worker for axes in itertools.permutations([0, 1], 2): 2372*da0073e9SAndroid Build Coastguard Worker # 2d, inner dimensions C 2373*da0073e9SAndroid Build Coastguard Worker x = torch.randn(n, m, device=device) 2374*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2375*da0073e9SAndroid Build Coastguard Worker 2376*da0073e9SAndroid Build Coastguard Worker # 2d, inner dimensions Fortran 2377*da0073e9SAndroid Build Coastguard Worker x = torch.randn(m, n, device=device).mT 2378*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2379*da0073e9SAndroid Build Coastguard Worker 2380*da0073e9SAndroid Build Coastguard Worker # 2d, inner dimensions non-contiguous 2381*da0073e9SAndroid Build Coastguard Worker x = torch.randn(n, 2 * m, device=device)[:, ::2] 2382*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2383*da0073e9SAndroid Build Coastguard Worker 2384*da0073e9SAndroid Build Coastguard Worker # 2d, all dimensions non-contiguous 2385*da0073e9SAndroid Build Coastguard Worker x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2] 2386*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2387*da0073e9SAndroid Build Coastguard Worker 2388*da0073e9SAndroid Build Coastguard Worker for o in range(1, 3): 2389*da0073e9SAndroid Build Coastguard Worker for axes in itertools.permutations([0, 1, 2], 2): 2390*da0073e9SAndroid Build Coastguard Worker # 3d, inner dimensions C 2391*da0073e9SAndroid Build Coastguard Worker x = torch.randn(o, n, m, device=device) 2392*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2393*da0073e9SAndroid Build Coastguard Worker 2394*da0073e9SAndroid Build Coastguard Worker # 3d, inner dimensions Fortran 2395*da0073e9SAndroid Build Coastguard Worker x = torch.randn(o, m, n, device=device).mT 2396*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2397*da0073e9SAndroid Build Coastguard Worker 2398*da0073e9SAndroid Build Coastguard Worker # 3d, inner dimensions non-contiguous 2399*da0073e9SAndroid Build Coastguard Worker x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2] 2400*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2401*da0073e9SAndroid Build Coastguard Worker 2402*da0073e9SAndroid Build Coastguard Worker # 3d, all dimensions non-contiguous 2403*da0073e9SAndroid Build Coastguard Worker x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2] 2404*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2405*da0073e9SAndroid Build Coastguard Worker 2406*da0073e9SAndroid Build Coastguard Worker for r in range(1, 3): 2407*da0073e9SAndroid Build Coastguard Worker for axes in itertools.permutations([0, 1, 2, 3], 2): 2408*da0073e9SAndroid Build Coastguard Worker # 4d, inner dimensions C 2409*da0073e9SAndroid Build Coastguard Worker x = torch.randn(r, o, n, m, device=device) 2410*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2411*da0073e9SAndroid Build Coastguard Worker 2412*da0073e9SAndroid Build Coastguard Worker # 4d, inner dimensions Fortran 
2413*da0073e9SAndroid Build Coastguard Worker x = torch.randn(r, o, n, m, device=device).mT 2414*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2415*da0073e9SAndroid Build Coastguard Worker 2416*da0073e9SAndroid Build Coastguard Worker # 4d, inner dimensions non-contiguous 2417*da0073e9SAndroid Build Coastguard Worker x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2] 2418*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2419*da0073e9SAndroid Build Coastguard Worker 2420*da0073e9SAndroid Build Coastguard Worker # 4d, all dimensions non-contiguous 2421*da0073e9SAndroid Build Coastguard Worker x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2] 2422*da0073e9SAndroid Build Coastguard Worker check_single_nuclear_norm(x, axes) 2423*da0073e9SAndroid Build Coastguard Worker 2424*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2425*da0073e9SAndroid Build Coastguard Worker def test_nuclear_norm_exceptions_old(self, device): 2426*da0073e9SAndroid Build Coastguard Worker for lst in [], [1], [1, 2]: 2427*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(lst, dtype=torch.double, device=device) 2428*da0073e9SAndroid Build Coastguard Worker for axes in (), (0,): 2429*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) 2430*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, torch.norm, x, "nuc", (0, 1)) 2431*da0073e9SAndroid Build Coastguard Worker 2432*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) 2433*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, "must be different", torch.norm, x, "nuc", (0, 0)) 2434*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) 2435*da0073e9SAndroid Build Coastguard Worker 2436*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 2437*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2438*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2439*da0073e9SAndroid Build Coastguard Worker def test_svd_lowrank(self, device, dtype): 2440*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix 2441*da0073e9SAndroid Build Coastguard Worker 2442*da0073e9SAndroid Build Coastguard Worker def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): 2443*da0073e9SAndroid Build Coastguard Worker density = options.pop('density', 1) 2444*da0073e9SAndroid Build Coastguard Worker if isinstance(matrix_size, int): 2445*da0073e9SAndroid Build Coastguard Worker rows = columns = matrix_size 2446*da0073e9SAndroid Build Coastguard Worker else: 2447*da0073e9SAndroid Build Coastguard Worker rows, columns = matrix_size 2448*da0073e9SAndroid Build Coastguard Worker if density == 1: 2449*da0073e9SAndroid Build Coastguard Worker a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) 2450*da0073e9SAndroid Build Coastguard Worker a = a_input 2451*da0073e9SAndroid Build Coastguard Worker else: 2452*da0073e9SAndroid Build Coastguard Worker assert batches == () 2453*da0073e9SAndroid Build Coastguard Worker a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) 2454*da0073e9SAndroid Build Coastguard Worker a = a_input.to_dense() 
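            # Request the maximal rank q = min(rows, columns) so that svd_lowrank returns enough singular triplets for the exact-reconstruction and torch.linalg.svd comparisons below.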
2455*da0073e9SAndroid Build Coastguard Worker 2456*da0073e9SAndroid Build Coastguard Worker q = min(rows, columns) 2457*da0073e9SAndroid Build Coastguard Worker u, s, v = svd_lowrank(a_input, q=q, **options) 2458*da0073e9SAndroid Build Coastguard Worker 2459*da0073e9SAndroid Build Coastguard Worker # check that (u, s, v) is an SVD 2460*da0073e9SAndroid Build Coastguard Worker u, s, v = u[..., :q], s[..., :q], v[..., :q] 2461*da0073e9SAndroid Build Coastguard Worker A = (u * s.unsqueeze(-2)).matmul(v.mH) 2462*da0073e9SAndroid Build Coastguard Worker self.assertEqual(A, a, rtol=1e-7, atol=2e-7) 2463*da0073e9SAndroid Build Coastguard Worker 2464*da0073e9SAndroid Build Coastguard Worker # check that svd_lowrank produces the same singular values as torch.linalg.svd 2465*da0073e9SAndroid Build Coastguard Worker U, S, Vh = torch.linalg.svd(a, full_matrices=False) 2466*da0073e9SAndroid Build Coastguard Worker V = Vh.mH 2467*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s, S) 2468*da0073e9SAndroid Build Coastguard Worker 2469*da0073e9SAndroid Build Coastguard Worker if density == 1: 2470*da0073e9SAndroid Build Coastguard Worker # actual_rank is known only for dense inputs 2471*da0073e9SAndroid Build Coastguard Worker # 2472*da0073e9SAndroid Build Coastguard Worker # check that pairs (u, U) and (v, V) span the same 2473*da0073e9SAndroid Build Coastguard Worker # subspaces, respectively 2474*da0073e9SAndroid Build Coastguard Worker u, v = u[..., :actual_rank], v[..., :actual_rank] 2475*da0073e9SAndroid Build Coastguard Worker U, V = U[..., :actual_rank], V[..., :actual_rank] 2476*da0073e9SAndroid Build Coastguard Worker expected_ones = u.mH.matmul(U).det().abs() 2477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_ones, torch.ones_like(expected_ones)) 2478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.mH.matmul(V).det().abs(), torch.ones_like(expected_ones)) 2479*da0073e9SAndroid Build Coastguard Worker 2480*da0073e9SAndroid Build Coastguard Worker all_batches = [(), (1,), (3,), (2, 3)] 2481*da0073e9SAndroid Build Coastguard Worker for actual_rank, size, all_batches in [ # noqa: B020 2482*da0073e9SAndroid Build Coastguard Worker (2, (17, 4), all_batches), 2483*da0073e9SAndroid Build Coastguard Worker (4, (17, 4), all_batches), 2484*da0073e9SAndroid Build Coastguard Worker (4, (17, 17), all_batches), 2485*da0073e9SAndroid Build Coastguard Worker (10, (100, 40), all_batches), 2486*da0073e9SAndroid Build Coastguard Worker (7, (1000, 1000), [()]), 2487*da0073e9SAndroid Build Coastguard Worker ]: 2488*da0073e9SAndroid Build Coastguard Worker # dense input 2489*da0073e9SAndroid Build Coastguard Worker for batches in all_batches: 2490*da0073e9SAndroid Build Coastguard Worker run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) 2491*da0073e9SAndroid Build Coastguard Worker if size != size[::-1]: 2492*da0073e9SAndroid Build Coastguard Worker run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) 2493*da0073e9SAndroid Build Coastguard Worker 2494*da0073e9SAndroid Build Coastguard Worker # sparse input 2495*da0073e9SAndroid Build Coastguard Worker for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: 2496*da0073e9SAndroid Build Coastguard Worker for density in [0.005, 0.1]: 2497*da0073e9SAndroid Build Coastguard Worker run_subtest(None, size, (), device, torch.svd_lowrank, density=density) 2498*da0073e9SAndroid Build Coastguard Worker 2499*da0073e9SAndroid Build Coastguard Worker # jitting support 2500*da0073e9SAndroid Build
Coastguard Worker jitted = torch.jit.script(torch.svd_lowrank) 2501*da0073e9SAndroid Build Coastguard Worker actual_rank, size, batches = 2, (17, 4), () 2502*da0073e9SAndroid Build Coastguard Worker run_subtest(actual_rank, size, batches, device, jitted) 2503*da0073e9SAndroid Build Coastguard Worker 2504*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagmaAndNoCusolver 2505*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2506*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4}) 2507*da0073e9SAndroid Build Coastguard Worker @setLinalgBackendsToDefaultFinally 2508*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2509*da0073e9SAndroid Build Coastguard Worker @serialTest() 2510*da0073e9SAndroid Build Coastguard Worker def test_svd(self, device, dtype): 2511*da0073e9SAndroid Build Coastguard Worker # tests linalg.svd, svd, linalg.svdvals 2512*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, dtype=dtype, device=device) 2513*da0073e9SAndroid Build Coastguard Worker 2514*da0073e9SAndroid Build Coastguard Worker backends = ["default"] 2515*da0073e9SAndroid Build Coastguard Worker 2516*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type == 'cuda': 2517*da0073e9SAndroid Build Coastguard Worker if torch.cuda.has_magma: 2518*da0073e9SAndroid Build Coastguard Worker backends.append("magma") 2519*da0073e9SAndroid Build Coastguard Worker if has_cusolver() or has_hipsolver(): 2520*da0073e9SAndroid Build Coastguard Worker backends.append("cusolver") 2521*da0073e9SAndroid Build Coastguard Worker 2522*da0073e9SAndroid Build Coastguard Worker ns = (12, 4, 2, 0) 2523*da0073e9SAndroid Build Coastguard Worker batches = ((), (0,), (1,), (2,), (2, 1), (0, 2)) 2524*da0073e9SAndroid Build Coastguard Worker drivers = (None, 'gesvd', 'gesvdj', 'gesvda') 2525*da0073e9SAndroid Build Coastguard Worker 2526*da0073e9SAndroid Build Coastguard Worker for backend in backends: 2527*da0073e9SAndroid Build Coastguard Worker torch.backends.cuda.preferred_linalg_library(backend) 2528*da0073e9SAndroid Build Coastguard Worker 2529*da0073e9SAndroid Build Coastguard Worker for batch, m, n, driver in product(batches, ns, ns, drivers): 2530*da0073e9SAndroid Build Coastguard Worker if not (backend == 'cusolver' or driver is None): 2531*da0073e9SAndroid Build Coastguard Worker # only test cases below and skip otherwise: 2532*da0073e9SAndroid Build Coastguard Worker # - backend == 'cusolver' (driver can be anything) 2533*da0073e9SAndroid Build Coastguard Worker # - backend != 'cusolver' (driver should only be None) 2534*da0073e9SAndroid Build Coastguard Worker continue 2535*da0073e9SAndroid Build Coastguard Worker 2536*da0073e9SAndroid Build Coastguard Worker shape = batch + (m, n) 2537*da0073e9SAndroid Build Coastguard Worker k = min(m, n) 2538*da0073e9SAndroid Build Coastguard Worker A = make_arg(shape) 2539*da0073e9SAndroid Build Coastguard Worker U, S, Vh = torch.linalg.svd(A, full_matrices=False, driver=driver) 2540*da0073e9SAndroid Build Coastguard Worker self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ Vh, A) 2541*da0073e9SAndroid Build Coastguard Worker 2542*da0073e9SAndroid Build Coastguard Worker U_f, S_f, Vh_f = torch.linalg.svd(A, full_matrices=True, driver=driver) 2543*da0073e9SAndroid Build Coastguard Worker self.assertEqual(S_f, S) 2544*da0073e9SAndroid Build Coastguard Worker self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ Vh_f[..., :k, :], A) 2545*da0073e9SAndroid 
Build Coastguard Worker 2546*da0073e9SAndroid Build Coastguard Worker S_s = torch.linalg.svdvals(A, driver=driver) 2547*da0073e9SAndroid Build Coastguard Worker self.assertEqual(S_s, S) 2548*da0073e9SAndroid Build Coastguard Worker 2549*da0073e9SAndroid Build Coastguard Worker U, S, V = torch.svd(A, some=True) 2550*da0073e9SAndroid Build Coastguard Worker self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ V.mH, A) 2551*da0073e9SAndroid Build Coastguard Worker 2552*da0073e9SAndroid Build Coastguard Worker U_f, S_f, V_f = torch.svd(A, some=False) 2553*da0073e9SAndroid Build Coastguard Worker self.assertEqual(S_f, S) 2554*da0073e9SAndroid Build Coastguard Worker self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ V_f[..., :k].mH, A) 2555*da0073e9SAndroid Build Coastguard Worker 2556*da0073e9SAndroid Build Coastguard Worker S_s = torch.svd(A, compute_uv=False).S 2557*da0073e9SAndroid Build Coastguard Worker self.assertEqual(S_s, S) 2558*da0073e9SAndroid Build Coastguard Worker 2559*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagmaAndNoCusolver 2560*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2561*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.complex128) 2562*da0073e9SAndroid Build Coastguard Worker def test_invariance_error_spectral_decompositions(self, device, dtype): 2563*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True) 2564*da0073e9SAndroid Build Coastguard Worker A = make_arg((3, 3)) 2565*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ill-defined"): 2566*da0073e9SAndroid Build Coastguard Worker U, _, Vh = torch.linalg.svd(A, full_matrices=False) 2567*da0073e9SAndroid Build Coastguard Worker (U + Vh).sum().abs().backward() 2568*da0073e9SAndroid Build Coastguard Worker 2569*da0073e9SAndroid Build Coastguard Worker A = make_arg((3, 3)) 2570*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ill-defined"): 2571*da0073e9SAndroid Build Coastguard Worker V = torch.linalg.eig(A).eigenvectors 2572*da0073e9SAndroid Build Coastguard Worker V.sum().abs().backward() 2573*da0073e9SAndroid Build Coastguard Worker 2574*da0073e9SAndroid Build Coastguard Worker A = make_arg((3, 3)) 2575*da0073e9SAndroid Build Coastguard Worker A = A + A.mH 2576*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ill-defined"): 2577*da0073e9SAndroid Build Coastguard Worker Q = torch.linalg.eigh(A).eigenvectors 2578*da0073e9SAndroid Build Coastguard Worker Q.sum().abs().backward() 2579*da0073e9SAndroid Build Coastguard Worker 2580*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver # MAGMA backend doesn't work in this case 2581*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) 2582*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2583*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2584*da0073e9SAndroid Build Coastguard Worker def test_svd_memory_allocation(self, device, dtype): 2585*da0073e9SAndroid Build Coastguard Worker # test for https://github.com/pytorch/pytorch/issues/61949 2586*da0073e9SAndroid Build Coastguard Worker # the problem was that tensors of incorrect size were allocated and then narrowed 2587*da0073e9SAndroid Build Coastguard Worker m = 3 2588*da0073e9SAndroid Build Coastguard Worker n = 2**20 2589*da0073e9SAndroid Build Coastguard Worker a = make_tensor((m, n), dtype=dtype, 
device=device) 2590*da0073e9SAndroid Build Coastguard Worker # the following should run without errors 2591*da0073e9SAndroid Build Coastguard Worker S = torch.linalg.svdvals(a) 2592*da0073e9SAndroid Build Coastguard Worker result = torch.linalg.svd(a, full_matrices=False) 2593*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.S, S) 2594*da0073e9SAndroid Build Coastguard Worker 2595*da0073e9SAndroid Build Coastguard Worker def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): 2596*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.common_utils import random_hermitian_pd_matrix 2597*da0073e9SAndroid Build Coastguard Worker 2598*da0073e9SAndroid Build Coastguard Worker b = torch.randn(*b_dims, dtype=dtype, device=device) 2599*da0073e9SAndroid Build Coastguard Worker A = random_hermitian_pd_matrix(*A_dims, dtype=dtype, device=device) 2600*da0073e9SAndroid Build Coastguard Worker L = torch.cholesky(A, upper=upper) 2601*da0073e9SAndroid Build Coastguard Worker return b, A, L 2602*da0073e9SAndroid Build Coastguard Worker 2603*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2604*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2605*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2606*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2607*da0073e9SAndroid Build Coastguard Worker torch.float64: 1e-8, torch.complex128: 1e-8}) 2608*da0073e9SAndroid Build Coastguard Worker def test_cholesky_solve(self, device, dtype): 2609*da0073e9SAndroid Build Coastguard Worker for (k, n), upper in itertools.product(zip([2, 3, 5], [3, 5, 7]), [True, False]): 2610*da0073e9SAndroid Build Coastguard Worker b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype) 2611*da0073e9SAndroid Build Coastguard Worker x = torch.cholesky_solve(b, L, upper=upper) 2612*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b, np.matmul(A.cpu(), x.cpu())) 2613*da0073e9SAndroid Build Coastguard Worker 2614*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2615*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2616*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2617*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2618*da0073e9SAndroid Build Coastguard Worker torch.float64: 1e-8, torch.complex128: 1e-8}) 2619*da0073e9SAndroid Build Coastguard Worker def test_cholesky_solve_batched(self, device, dtype): 2620*da0073e9SAndroid Build Coastguard Worker def cholesky_solve_batch_helper(A_dims, b_dims, upper): 2621*da0073e9SAndroid Build Coastguard Worker b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype) 2622*da0073e9SAndroid Build Coastguard Worker x_exp_list = [] 2623*da0073e9SAndroid Build Coastguard Worker for i in range(b_dims[0]): 2624*da0073e9SAndroid Build Coastguard Worker x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper)) 2625*da0073e9SAndroid Build Coastguard Worker x_exp = torch.stack(x_exp_list) # Stacked output 2626*da0073e9SAndroid Build Coastguard Worker x_act = torch.cholesky_solve(b, L, upper=upper) # Actual output 2627*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_act, x_exp) # Equality check 2628*da0073e9SAndroid Build Coastguard Worker Ax = np.matmul(A.cpu(), x_act.cpu()) 2629*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b, Ax) # Correctness check 
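            # A_dims = (5, batchsize) yields `batchsize` Hermitian positive-definite 5x5 matrices, and b_dims = (batchsize, 5, 10) supplies 10 right-hand sides per matrix.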
2630*da0073e9SAndroid Build Coastguard Worker 2631*da0073e9SAndroid Build Coastguard Worker for upper, batchsize in itertools.product([True, False], [1, 3, 4]): 2632*da0073e9SAndroid Build Coastguard Worker cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper) 2633*da0073e9SAndroid Build Coastguard Worker 2634*da0073e9SAndroid Build Coastguard Worker @slowTest 2635*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2636*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2637*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2638*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2639*da0073e9SAndroid Build Coastguard Worker torch.float64: 1e-8, torch.complex128: 1e-8}) 2640*da0073e9SAndroid Build Coastguard Worker def test_cholesky_solve_batched_many_batches(self, device, dtype): 2641*da0073e9SAndroid Build Coastguard Worker for A_dims, b_dims in zip([(5, 256, 256), (5,)], [(5, 10), (512, 512, 5, 10)]): 2642*da0073e9SAndroid Build Coastguard Worker for upper in [True, False]: 2643*da0073e9SAndroid Build Coastguard Worker b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype) 2644*da0073e9SAndroid Build Coastguard Worker x = torch.cholesky_solve(b, L, upper) 2645*da0073e9SAndroid Build Coastguard Worker Ax = torch.matmul(A, x) 2646*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Ax, b.expand_as(Ax)) 2647*da0073e9SAndroid Build Coastguard Worker 2648*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2649*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2650*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2651*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2652*da0073e9SAndroid Build Coastguard Worker torch.float64: 1e-8, torch.complex128: 1e-8}) 2653*da0073e9SAndroid Build Coastguard Worker def test_cholesky_solve_batched_broadcasting(self, device, dtype): 2654*da0073e9SAndroid Build Coastguard Worker from numpy.linalg import solve 2655*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.common_utils import random_hermitian_pd_matrix 2656*da0073e9SAndroid Build Coastguard Worker 2657*da0073e9SAndroid Build Coastguard Worker def run_test(A_dims, b_dims, upper): 2658*da0073e9SAndroid Build Coastguard Worker A_matrix_size = A_dims[-1] 2659*da0073e9SAndroid Build Coastguard Worker A_batch_dims = A_dims[:-2] 2660*da0073e9SAndroid Build Coastguard Worker A = random_hermitian_pd_matrix(A_matrix_size, *A_batch_dims, 2661*da0073e9SAndroid Build Coastguard Worker dtype=dtype, device='cpu') 2662*da0073e9SAndroid Build Coastguard Worker b = torch.randn(*b_dims, dtype=dtype, device='cpu') 2663*da0073e9SAndroid Build Coastguard Worker x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device) 2664*da0073e9SAndroid Build Coastguard Worker A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device) 2665*da0073e9SAndroid Build Coastguard Worker L = torch.linalg.cholesky(A, upper=upper) 2666*da0073e9SAndroid Build Coastguard Worker x = torch.cholesky_solve(b, L, upper=upper) 2667*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x_exp) 2668*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/42695 2669*da0073e9SAndroid Build Coastguard Worker x = torch.cholesky_solve(b, L, upper=upper, out=x) 2670*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(x, x_exp) 2671*da0073e9SAndroid Build Coastguard Worker 2672*da0073e9SAndroid Build Coastguard Worker # test against numpy.linalg.solve 2673*da0073e9SAndroid Build Coastguard Worker for upper in [True, False]: 2674*da0073e9SAndroid Build Coastguard Worker run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper) # no broadcasting 2675*da0073e9SAndroid Build Coastguard Worker run_test((2, 1, 3, 4, 4), (4, 6), upper) # broadcasting b 2676*da0073e9SAndroid Build Coastguard Worker run_test((4, 4), (2, 1, 3, 4, 2), upper) # broadcasting A 2677*da0073e9SAndroid Build Coastguard Worker run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper) # broadcasting A & b 2678*da0073e9SAndroid Build Coastguard Worker 2679*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2680*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2681*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2682*da0073e9SAndroid Build Coastguard Worker def test_cholesky_solve_out_errors_and_warnings(self, device, dtype): 2683*da0073e9SAndroid Build Coastguard Worker # dtypes should be safely castable 2684*da0073e9SAndroid Build Coastguard Worker a = torch.eye(2, dtype=dtype, device=device) 2685*da0073e9SAndroid Build Coastguard Worker b = torch.randn(2, 1, dtype=dtype, device=device) 2686*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0, dtype=torch.int, device=device) 2687*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 2688*da0073e9SAndroid Build Coastguard Worker torch.cholesky_solve(b, a, out=out) 2689*da0073e9SAndroid Build Coastguard Worker 2690*da0073e9SAndroid Build Coastguard Worker # device should match 2691*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 2692*da0073e9SAndroid Build Coastguard Worker wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 2693*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0, dtype=dtype, device=wrong_device) 2694*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"): 2695*da0073e9SAndroid Build Coastguard Worker torch.cholesky_solve(b, a, out=out) 2696*da0073e9SAndroid Build Coastguard Worker 2697*da0073e9SAndroid Build Coastguard Worker # if out tensor with wrong shape is passed a warning is given 2698*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 2699*da0073e9SAndroid Build Coastguard Worker out = torch.empty(1, dtype=dtype, device=device) 2700*da0073e9SAndroid Build Coastguard Worker # Trigger warning 2701*da0073e9SAndroid Build Coastguard Worker torch.cholesky_solve(b, a, out=out) 2702*da0073e9SAndroid Build Coastguard Worker # Check warning occurs 2703*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 2704*da0073e9SAndroid Build Coastguard Worker self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 2705*da0073e9SAndroid Build Coastguard Worker 2706*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 2707*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2708*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 2709*da0073e9SAndroid Build Coastguard Worker def test_cholesky_solve_backward(self, device, dtype): 2710*da0073e9SAndroid Build Coastguard Worker b_dims = (5, 2) 2711*da0073e9SAndroid Build Coastguard Worker L_dims = (5, 5) 2712*da0073e9SAndroid Build Coastguard Worker 2713*da0073e9SAndroid Build Coastguard Worker for 
test_L_grad in (False, True): 2714*da0073e9SAndroid Build Coastguard Worker b = torch.randn(*b_dims, dtype=dtype, device=device, requires_grad=True) 2715*da0073e9SAndroid Build Coastguard Worker L = torch.randn(*L_dims, dtype=dtype, device=device, requires_grad=test_L_grad) 2716*da0073e9SAndroid Build Coastguard Worker if test_L_grad: 2717*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(lambda b, L: torch.cholesky_solve(b, torch.tril(L), upper=False), (b, L)) 2718*da0073e9SAndroid Build Coastguard Worker else: 2719*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(lambda b: torch.cholesky_solve(b, L, upper=False), (b,)) 2720*da0073e9SAndroid Build Coastguard Worker 2721*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagmaAndNoCusolver 2722*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2723*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2724*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3, 2725*da0073e9SAndroid Build Coastguard Worker torch.float64: 1e-8, torch.complex128: 1e-8}) 2726*da0073e9SAndroid Build Coastguard Worker def test_inverse(self, device, dtype): 2727*da0073e9SAndroid Build Coastguard Worker make_fullrank = make_fullrank_matrices_with_distinct_singular_values 2728*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_fullrank, device=device, dtype=dtype) 2729*da0073e9SAndroid Build Coastguard Worker 2730*da0073e9SAndroid Build Coastguard Worker def run_test(torch_inverse, matrix, batches, n): 2731*da0073e9SAndroid Build Coastguard Worker matrix_inverse = torch_inverse(matrix) 2732*da0073e9SAndroid Build Coastguard Worker 2733*da0073e9SAndroid Build Coastguard Worker # Compare against NumPy output 2734*da0073e9SAndroid Build Coastguard Worker # NumPy uses 'gesv' LAPACK routine solving the equation A A_inv = I 2735*da0073e9SAndroid Build Coastguard Worker # But in PyTorch 'getrf' + 'getrs' is used.
As such, there may be some element-wise differences 2736*da0073e9SAndroid Build Coastguard Worker expected = np.linalg.inv(matrix.cpu().numpy()) 2737*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_inverse, expected, atol=self.precision, rtol=self.precision) 2738*da0073e9SAndroid Build Coastguard Worker 2739*da0073e9SAndroid Build Coastguard Worker # Additional correctness tests, check matrix*matrix_inverse == identity 2740*da0073e9SAndroid Build Coastguard Worker identity = torch.eye(n, dtype=dtype, device=device) 2741*da0073e9SAndroid Build Coastguard Worker self.assertEqual(identity.expand_as(matrix), np.matmul(matrix.cpu(), matrix_inverse.cpu())) 2742*da0073e9SAndroid Build Coastguard Worker self.assertEqual(identity.expand_as(matrix), np.matmul(matrix_inverse.cpu(), matrix.cpu())) 2743*da0073e9SAndroid Build Coastguard Worker 2744*da0073e9SAndroid Build Coastguard Worker # check the out= variant 2745*da0073e9SAndroid Build Coastguard Worker # prepare the expected out tensor 2746*da0073e9SAndroid Build Coastguard Worker matrix_inverse_out = torch.empty(*batches, n, n, dtype=dtype, device=device) 2747*da0073e9SAndroid Build Coastguard Worker matrix_inverse_out_t = matrix_inverse_out.mT.clone(memory_format=torch.contiguous_format) 2748*da0073e9SAndroid Build Coastguard Worker matrix_inverse_out = matrix_inverse_out_t.mT 2749*da0073e9SAndroid Build Coastguard Worker ans = torch_inverse(matrix, out=matrix_inverse_out) 2750*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_inverse_out, ans, atol=0, rtol=0) 2751*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0) 2752*da0073e9SAndroid Build Coastguard Worker 2753*da0073e9SAndroid Build Coastguard Worker # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix 2754*da0073e9SAndroid Build Coastguard Worker if matrix.ndim > 2 and batches[0] != 0: 2755*da0073e9SAndroid Build Coastguard Worker expected_inv_list = [] 2756*da0073e9SAndroid Build Coastguard Worker p = int(np.prod(batches)) # use `p` instead of -1, so that the test works for empty input as well 2757*da0073e9SAndroid Build Coastguard Worker for mat in matrix.contiguous().view(p, n, n): 2758*da0073e9SAndroid Build Coastguard Worker expected_inv_list.append(torch_inverse(mat)) 2759*da0073e9SAndroid Build Coastguard Worker expected_inv = torch.stack(expected_inv_list).view(*batches, n, n) 2760*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cuda' and dtype in [torch.float32, torch.complex64]: 2761*da0073e9SAndroid Build Coastguard Worker # single-inverse is done using cuSOLVER, while batched inverse is done using MAGMA 2762*da0073e9SAndroid Build Coastguard Worker # individual values can be significantly different for fp32, hence rather high rtol is used 2763*da0073e9SAndroid Build Coastguard Worker # the important thing is that torch_inverse passes above checks with identity 2764*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_inverse, expected_inv, atol=1e-1, rtol=1e-2) 2765*da0073e9SAndroid Build Coastguard Worker else: 2766*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_inverse, expected_inv) 2767*da0073e9SAndroid Build Coastguard Worker 2768*da0073e9SAndroid Build Coastguard Worker # helper function for testing torch.linalg.inv_ex 2769*da0073e9SAndroid Build Coastguard Worker def test_inv_ex(input, out=None): 2770*da0073e9SAndroid Build Coastguard Worker if out is not None: 2771*da0073e9SAndroid 
Build Coastguard Worker info = torch.empty(0, dtype=torch.int32, device=device) 2772*da0073e9SAndroid Build Coastguard Worker return torch.linalg.inv_ex(input, out=(out, info)).inverse 2773*da0073e9SAndroid Build Coastguard Worker return torch.linalg.inv_ex(input).inverse 2774*da0073e9SAndroid Build Coastguard Worker 2775*da0073e9SAndroid Build Coastguard Worker for torch_inverse in [torch.inverse, torch.linalg.inv, test_inv_ex]: 2776*da0073e9SAndroid Build Coastguard Worker for batches, n in itertools.product( 2777*da0073e9SAndroid Build Coastguard Worker [[], [0], [2], [2, 1]], 2778*da0073e9SAndroid Build Coastguard Worker [0, 5] 2779*da0073e9SAndroid Build Coastguard Worker ): 2780*da0073e9SAndroid Build Coastguard Worker matrices = make_arg(*batches, n, n) 2781*da0073e9SAndroid Build Coastguard Worker run_test(torch_inverse, matrices, batches, n) 2782*da0073e9SAndroid Build Coastguard Worker 2783*da0073e9SAndroid Build Coastguard Worker # test non-contiguous input 2784*da0073e9SAndroid Build Coastguard Worker run_test(torch_inverse, matrices.mT, batches, n) 2785*da0073e9SAndroid Build Coastguard Worker if n > 0: 2786*da0073e9SAndroid Build Coastguard Worker run_test( 2787*da0073e9SAndroid Build Coastguard Worker torch_inverse, 2788*da0073e9SAndroid Build Coastguard Worker make_arg(*batches, 2 * n, 2 * n) 2789*da0073e9SAndroid Build Coastguard Worker .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n), 2790*da0073e9SAndroid Build Coastguard Worker batches, n 2791*da0073e9SAndroid Build Coastguard Worker ) 2792*da0073e9SAndroid Build Coastguard Worker 2793*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagmaAndNoCusolver 2794*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2795*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2796*da0073e9SAndroid Build Coastguard Worker def test_inv_ex_info_device(self, device, dtype): 2797*da0073e9SAndroid Build Coastguard Worker A = torch.eye(3, 3, dtype=dtype, device=device) 2798*da0073e9SAndroid Build Coastguard Worker info = torch.linalg.inv_ex(A).info 2799*da0073e9SAndroid Build Coastguard Worker self.assertTrue(info.device == A.device) 2800*da0073e9SAndroid Build Coastguard Worker 2801*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagmaAndNoCusolver 2802*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 2803*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2804*da0073e9SAndroid Build Coastguard Worker def test_inv_ex_singular(self, device, dtype): 2805*da0073e9SAndroid Build Coastguard Worker # if the input matrix is not invertible, info with positive integer is returned 2806*da0073e9SAndroid Build Coastguard Worker A = torch.eye(3, 3, dtype=dtype, device=device) 2807*da0073e9SAndroid Build Coastguard Worker A[-1, -1] = 0 # Now A is singular 2808*da0073e9SAndroid Build Coastguard Worker info = torch.linalg.inv_ex(A).info 2809*da0073e9SAndroid Build Coastguard Worker self.assertEqual(info, 3) 2810*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.linalg.LinAlgError, 2811*da0073e9SAndroid Build Coastguard Worker r'diagonal element 3 is zero, the inversion could not be completed'): 2812*da0073e9SAndroid Build Coastguard Worker torch.linalg.inv_ex(A, check_errors=True) 2813*da0073e9SAndroid Build Coastguard Worker 2814*da0073e9SAndroid Build Coastguard Worker # if at least one matrix in the batch is not invertible, 2815*da0073e9SAndroid Build Coastguard Worker # batched info with positive integer for the
        # corresponding matrix
        A = torch.eye(3, 3, dtype=dtype, device=device)
        A = A.reshape((1, 3, 3))
        A = A.repeat(5, 1, 1)
        A[3, -2, -2] = 0  # Now A[3] is singular
        info = torch.linalg.inv_ex(A).info

        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
        expected_info[3] = 2
        self.assertEqual(info, expected_info)
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
            torch.linalg.inv_ex(A, check_errors=True)

    @slowTest
    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
                        torch.float64: 1e-5, torch.complex128: 1e-5})
    def test_inverse_many_batches(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_arg = partial(make_fullrank, device=device, dtype=dtype)

        def test_inverse_many_batches_helper(torch_inverse, b, n):
            matrices = make_arg(b, n, n)
            matrices_inverse = torch_inverse(matrices)

            # Compare against NumPy output
            expected = np.linalg.inv(matrices.cpu().numpy())
            self.assertEqual(matrices_inverse, expected, atol=self.precision, rtol=1e-3)

        for torch_inverse in [torch.inverse, torch.linalg.inv]:
            test_inverse_many_batches_helper(torch_inverse, 5, 256)
            test_inverse_many_batches_helper(torch_inverse, 3, 512)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @onlyNativeDeviceTypes  # TODO: XLA doesn't raise exception
    @dtypes(*floating_and_complex_types())
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
    def test_inverse_errors(self, device, dtype):
        # inverse expects batches of square matrices as input
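        # (e.g. the (2, 3, 4, 3) input below is a 2 x 3 batch of 4 x 3 matrices, which are not square)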
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.inverse(torch.randn(2, 3, 4, 3))

        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
        def run_test_singular_input(batch_dim, n):
            x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
            x[n, -1, -1] = 0
            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
                torch.inverse(x)

        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
            run_test_singular_input(*params)

    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @onlyNativeDeviceTypes  # TODO: XLA doesn't raise exception
    @dtypes(*floating_and_complex_types())
    def test_inverse_errors_large(self, device, dtype):
        # Test batched inverse of singular matrices reports errors without crashing (gh-51930)
        x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
        x[:] = torch.eye(616, dtype=dtype, device=device)
        x[..., 10, 10] = 0
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'):
            torch.inverse(x)

    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_pinv(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test_main(A, hermitian):
            # Testing against definition for pseudo-inverses
            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
            np_A = A.cpu().numpy()
            np_A_pinv = A_pinv.cpu().numpy()
            if A.numel() > 0:
                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=self.precision, rtol=self.precision)
                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=self.precision, rtol=self.precision)
                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1))
                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1))
            else:
                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))

            # Check out= variant
            out = torch.empty_like(A_pinv)
            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, A_pinv)

        def run_test_numpy(A, hermitian):
            # Check against NumPy output
            # Test float rcond, and specific value for each matrix
            rconds = [float(torch.rand(1)), ]
            # Test different types of rcond tensor
            for rcond_type in all_types():
                rconds.append(torch.rand(A.shape[:-2], dtype=torch.double, device=device).to(rcond_type))
            # Test broadcasting of rcond
            if A.ndim > 2:
                rconds.append(torch.rand(A.shape[-3], device=device))
            for rcond in rconds:
                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
                self.assertEqual(actual, torch_rtol)
                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
                self.assertEqual(actual, expected, atol=self.precision, rtol=1e-5)

        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
            A = torch.randn(*sizes, dtype=dtype, device=device)
            hermitian = False
            run_test_main(A, hermitian)
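            # run_test_numpy additionally exercises the rcond / rtol arguments against np.linalg.pinv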
            run_test_numpy(A, hermitian)

        # Check hermitian = True
        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
            hermitian = True
            run_test_main(A, hermitian)
            run_test_numpy(A, hermitian)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_pinv_errors_and_warnings(self, device, dtype):
        # pinv requires at least 2D tensor
        a = torch.randn(1, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "expected a tensor with 2 or more dimensions"):
            torch.linalg.pinv(a)

        # if non-empty out tensor with wrong shape is passed a warning is given
        a = torch.randn(3, 3, dtype=dtype, device=device)
        out = torch.empty(7, 7, dtype=dtype, device=device)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.pinv(a, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes of out and input should be safely castable
        out = torch.empty_like(a).to(torch.int)
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
            torch.linalg.pinv(a, out=out)

        if torch.cuda.is_available():
            # device of out and input should match
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty_like(a).to(wrong_device)
            with self.assertRaisesRegex(RuntimeError, "Expected result and input tensors to be on the same device"):
                torch.linalg.pinv(a, out=out)

            # device of rcond and input should match
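            # (rcond participates in the thresholding of singular values, so it is expected on the same device as the input)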
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            rcond = torch.full((), 1e-2, device=wrong_device)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                torch.linalg.pinv(a, rcond=rcond)

        # rcond can't be complex
        rcond = torch.full((), 1j, device=device)
        with self.assertRaisesRegex(RuntimeError, "rcond tensor of complex type is not supported"):
            torch.linalg.pinv(a, rcond=rcond)

        # atol can't be complex
        atol = torch.full((), 1j, device=device)
        with self.assertRaisesRegex(RuntimeError, "atol tensor of complex type is not supported"):
            torch.linalg.pinv(a, atol=atol)

        # rtol can't be complex
        rtol = torch.full((), 1j, device=device)
        with self.assertRaisesRegex(RuntimeError, "rtol tensor of complex type is not supported"):
            torch.linalg.pinv(a, rtol=rtol)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
    def test_inv_errors_and_warnings(self, device, dtype):
        # inv expects batches of square matrices as input
        a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.linalg.inv(a)

        # inv requires the input to be at least a 2-dimensional tensor
        a = torch.randn(2, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
            torch.linalg.inv(a)

        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
        def run_test_singular_input(batch_dim, n):
            a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
            a[n, -1, -1] = 0
            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
                torch.linalg.inv(a)

        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
            run_test_singular_input(*params)

        # dtypes should match
        a = torch.eye(2, dtype=dtype, device=device)
        out = torch.empty(0, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
            torch.linalg.inv(a, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.inv(a, out=out)

        # if out tensor with wrong shape is passed a warning is given
        with warnings.catch_warnings(record=True) as w:
            a = torch.eye(2, dtype=dtype, device=device)
            out = torch.empty(1, dtype=dtype, device=device)
            # Trigger warning
            torch.linalg.inv(a, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # if out tensor is in batched column major format but has the wrong shape, a warning is given
        with warnings.catch_warnings(record=True) as w:
            a = torch.eye(2, dtype=dtype, device=device)
            out = torch.empty(3, 3, dtype=dtype, device=device)
            out = out.mT.clone(memory_format=torch.contiguous_format)
            out = out.mT
            self.assertTrue(out.mT.is_contiguous())
            # Trigger warning
            torch.linalg.inv(a, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

    def solve_test_helper(self, A_dims, b_dims, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_A = partial(make_fullrank, device=device, dtype=dtype)

        b = torch.randn(*b_dims, dtype=dtype, device=device)
        A = make_A(*A_dims)
        return b, A

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
    def test_solve(self, device, dtype):
        def run_test(n, batch, rhs):
            A_dims = (*batch, n, n)
            b_dims = (*batch, n, *rhs)
            b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)

            # Correctness test
            x = torch.linalg.solve(A, b)
            if rhs == ():
                Ax = np.matmul(A.cpu(), x.unsqueeze(-1).cpu())
                Ax.squeeze_(-1)
            else:
                Ax = np.matmul(A.cpu(), x.cpu())
            self.assertEqual(b.expand_as(Ax), Ax)

            # Check against NumPy
            expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy())
            self.assertEqual(x, expected)

        batches = [(), (0, ), (3, ), (2, 3)]
        ns = [0, 5, 32]
        nrhs = [(), (1, ), (5, )]
        for n, batch, rhs in itertools.product(ns, batches, nrhs):
            run_test(n, batch, rhs)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_solve_batched_broadcasting(self, device, dtype):
        from numpy.linalg import solve

        def run_test(A_dims, B_dims):
            A_matrix_size = A_dims[-1]
            A_batch_dims = A_dims[:-2]
            B, A = self.solve_test_helper(A_batch_dims + (A_matrix_size, A_matrix_size), B_dims, device, dtype)
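            # torch.linalg.solve broadcasts the batch dimensions of A and B, matching NumPy's solve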
            actual = torch.linalg.solve(A, B)
            expected = solve(A.cpu().numpy(), B.cpu().numpy())
            self.assertEqual(actual, expected)

        # test against numpy.linalg.solve
        run_test((5, 5), (2, 0, 5, 3))  # broadcasting with 0 batch dim
        run_test((2, 0, 5, 5), (5, 3))  # broadcasting with 0 batch dim
        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting B
        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & B

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
    def test_tensorsolve(self, device, dtype):
        def run_test(a_shape, dims):
            a = torch.randn(a_shape, dtype=dtype, device=device)
            b = torch.randn(a_shape[:2], dtype=dtype, device=device)
            result = torch.linalg.tensorsolve(a, b, dims=dims)
            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
            self.assertEqual(result, expected)

            # check the out= variant
            out = torch.empty_like(result)
            ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, result)

        a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
        dims = [None, (0, 2)]
        for a_shape, d in itertools.product(a_shapes, dims):
            run_test(a_shape, d)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_tensorsolve_empty(self, device, dtype):
        # Check for empty inputs. NumPy does not work for these cases.
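        # the solution x satisfies tensordot(a, x, dims=x.ndim) == b even when the tensors have zero elements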
        a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
        b = torch.empty(a.shape[:2], dtype=dtype, device=device)
        x = torch.linalg.tensorsolve(a, b)
        self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float32)
    def test_tensorsolve_errors_and_warnings(self, device, dtype):
        # tensorsolve expects an input that can be reshaped to a square matrix
        a = torch.eye(2 * 3 * 4, dtype=dtype, device=device).reshape((2 * 3, 4, 2, 3, 4))
        b = torch.randn(8, 4, dtype=dtype, device=device)
        self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape))
        with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'):
            torch.linalg.tensorsolve(a, b)

        # if non-empty out tensor with wrong shape is passed a warning is given
        out = torch.empty_like(a)
        b = torch.randn(6, 4, dtype=dtype, device=device)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.tensorsolve(a, b, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out = torch.empty_like(a).to(torch.int)
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
            torch.linalg.tensorsolve(a, b, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, dtype=dtype, device=wrong_device)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.tensorsolve(a, b, out=out)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
    def test_tensorinv(self, device, dtype):

        def run_test(a_shape, ind):
            a = torch.randn(a_shape, dtype=dtype, device=device)
            a_numpy = a.cpu().numpy()
            result = torch.linalg.tensorinv(a, ind=ind)
            expected = np.linalg.tensorinv(a_numpy, ind=ind)
            self.assertEqual(result, expected)

            # check the out= variant
            out = torch.empty_like(result)
            ans = torch.linalg.tensorinv(a, ind=ind, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, result)

        # compare to NumPy output
        run_test((12, 3, 4), ind=1)
        run_test((3, 8, 24), ind=2)
        run_test((18, 3, 3, 2), ind=1)
        run_test((1, 4, 2, 2), ind=2)
        run_test((2, 3, 5, 30), ind=3)
        run_test((24, 2, 2, 3, 2), ind=1)
        run_test((3, 4, 2, 3, 2), ind=2)
        run_test((1, 2, 3, 2, 3), ind=3)
        run_test((3, 2, 1, 2, 12), ind=4)

    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_tensorinv_empty(self, device, dtype):
        for ind in range(1, 4):
            # Check for empty inputs. NumPy does not work for these cases.
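            # tensorinv swaps the leading and trailing groups of dimensions: the result has shape a.shape[ind:] + a.shape[:ind]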
            a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
            a_inv = torch.linalg.tensorinv(a, ind=ind)
            self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind])

    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_tensorinv_errors_and_warnings(self, device, dtype):

        def check_shape(a_shape, ind):
            # tensorinv requires the input to satisfy
            # prod(a.shape[ind:]) == prod(a.shape[:ind])
            a = torch.randn(a_shape, dtype=dtype, device=device)
            with self.assertRaisesRegex(RuntimeError, "Expected self to satisfy the requirement"):
                torch.linalg.tensorinv(a, ind=ind)

        def check_ind(a_shape, ind):
            a = torch.randn(a_shape, dtype=dtype, device=device)
            with self.assertRaisesRegex(RuntimeError, "Expected a strictly positive integer"):
                torch.linalg.tensorinv(a, ind=ind)

        def check_out(a_shape, ind):
            # if non-empty out tensor with wrong shape is passed a warning is given
            a = torch.randn(a_shape, dtype=dtype, device=device)
            out = torch.empty_like(a)
            with warnings.catch_warnings(record=True) as w:
                # Trigger warning
                torch.linalg.tensorinv(a, ind=ind, out=out)
                # Check warning occurs
                self.assertEqual(len(w), 1)
                self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

            # dtypes should be safely castable
            out = torch.empty(0, dtype=torch.int, device=device)
            with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
                torch.linalg.tensorinv(a, ind=ind, out=out)

            # device should match
            if torch.cuda.is_available():
                wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
                out = torch.empty(0, dtype=dtype, device=wrong_device)
                with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                    torch.linalg.tensorinv(a, ind=ind, out=out)

        # test for invalid shape
        check_shape((2, 3, 4), ind=1)
        check_shape((1, 2, 3, 4), ind=3)

        # test for invalid ind
        check_ind((12, 3, 4), ind=-1)
        check_ind((18, 3, 3, 2), ind=0)

        # test for invalid out tensor
        check_out((12, 3, 4), ind=1)
        check_out((3, 8, 24), ind=2)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_tensorinv_singular_input(self, device, dtype):

        def check_singular_input(a_shape, ind):
            prod_ind_end = np.prod(a_shape[ind:])
            a = torch.eye(prod_ind_end, dtype=dtype, device=device)
            a[-1, -1] = 0  # Now `a` is singular
            a = a.reshape(a_shape)
            with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"):
                torch.linalg.tensorinv(a, ind=ind)

        # test for non-invertible input
        check_singular_input((12, 3, 4), ind=1)
        check_singular_input((3, 6, 18), ind=2)

    def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
        def check(x, y):
            # Compare with numpy
            res = torch_fn(x, y)
            if x.dtype == torch.bfloat16:
                ref = torch.from_numpy(np.array(np_fn(x.cpu().float().numpy(), y.cpu().float().numpy())))
            else:
                ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy())))
            if res.dtype == torch.bfloat16:
                self.assertEqual(res.cpu(), ref.bfloat16())
            else:
                self.assertEqual(res.cpu(), ref)

            # Test out variant
            out = torch.empty_like(res)
            torch_fn(x, y, out=out)
            self.assertEqual(out, res)

        # Empty
        x = torch.tensor([], dtype=dtype, device=device)
        y = torch.tensor([], dtype=dtype, device=device)
        check(x, y)

        # Contiguous
        x = 0.1 * torch.randn(5000, dtype=dtype, device=device)
        y = 0.1 * torch.randn(5000, dtype=dtype, device=device)
        check(x, y)

        # 0 strided
        y = 0.1 * torch.randn(1, dtype=dtype, device=device).expand(5000)
        check(x, y)

        # 2 strided
        check(x[::2], y[::2])

    @dtypes(torch.float, torch.cfloat, torch.bfloat16, torch.float16)
    @dtypesIfCUDA(torch.float, torch.cfloat)
    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5, torch.bfloat16: 1e-0})
    def test_dot_vs_numpy(self, device, dtype):
        self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)

    @dtypes(torch.float, torch.cfloat)
    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
    def test_vdot_vs_numpy(self, device, dtype):
        self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)

    def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
        def check(x, y, regex):
            with self.assertRaisesRegex(RuntimeError, regex):
                torch_fn(x, y)

        if complex_dtypes:
            x = torch.randn(1, dtype=torch.cfloat, device=device)
            y = torch.randn(3, dtype=torch.cdouble, device=device)
        else:
            x = torch.randn(1, dtype=torch.float, device=device)
            y = torch.randn(3, dtype=torch.double, device=device)

        check(x, y, 'dot : expected both vectors to have same dtype')
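        # the remaining checks cover dimensionality, size, and device mismatches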
        check(x.reshape(1, 1), y, '1D tensors expected')
        check(x.expand(9), y.to(x.dtype), 'inconsistent tensor size')

        if self.device_type != 'cpu':
            x_cpu = x.expand(3).cpu()
            check(x_cpu, y.to(x.dtype), 'Expected all tensors to be on the same device')

    @onlyNativeDeviceTypes
    def test_vdot_invalid_args(self, device):
        self._test_dot_vdot_invalid_args(device, torch.vdot)
        self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)

    @onlyNativeDeviceTypes
    def test_dot_invalid_args(self, device):
        self._test_dot_vdot_invalid_args(device, torch.dot)
        self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_matrix_rank(self, device, dtype):
        matrix_rank = torch.linalg.matrix_rank

        def run_test(shape0, shape1, batch):
            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
            rank_a = matrix_rank(a)

            self.assertEqual(rank_a, matrix_rank(a.mH))
            aaH = torch.matmul(a, a.mH)
            rank_aaH = matrix_rank(aaH)
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
            self.assertEqual(rank_aaH, rank_aaH_hermitian)
            aHa = torch.matmul(a.mH, a)
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))

            # check against NumPy
            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))

            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))

            # hermitian flag for NumPy was added in 1.14.0
            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
                self.assertEqual(rank_aaH_hermitian,
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
                self.assertEqual(matrix_rank(aaH, 0.01, True),
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))

            # check out= variant
            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
            ans = matrix_rank(a, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, rank_a)

        shapes = (3, 13)
        batches = ((), (0, ), (4, ), (3, 5, ))
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
            run_test(shape0, shape1, batch)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_matrix_rank_atol(self, device, dtype):

        def run_test_atol(shape0, shape1, batch):
            a = make_tensor((*batch, shape0, shape1), dtype=dtype, device=device)
            # Check against NumPy output
            # Test float tol, and specific value for each matrix
            tolerances = [float(torch.rand(1)), ]
            # Test different types of tol tensor
            for tol_type in all_types():
                tolerances.append(make_tensor(a.shape[:-2], dtype=tol_type, device=device, low=0))
            # Test broadcasting of tol
            if a.ndim > 2:
                tolerances.append(make_tensor(a.shape[-3], dtype=torch.float32, device=device, low=0))
            for tol in tolerances:
                actual = torch.linalg.matrix_rank(a, atol=tol)
                actual_tol = torch.linalg.matrix_rank(a, tol=tol)
                self.assertEqual(actual, actual_tol)
                numpy_tol = tol if isinstance(tol, float) else tol.cpu().numpy()
                expected = np.linalg.matrix_rank(a.cpu().numpy(), tol=numpy_tol)
                self.assertEqual(actual, expected)

        shapes = (3, 13)
        batches = ((), (0, ), (4, ), (3, 5, ))
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
            run_test_atol(shape0, shape1, batch)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float64)
    def test_matrix_rank_atol_rtol(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_arg = partial(make_fullrank, device=device, dtype=dtype)

        # creates a full-rank n x n matrix with distinct singular values in the range [2/3, 3/2];
        # the singular values are 1 + 1/2, 1 - 1/3, 1 + 1/4, 1 - 1/5, ...
        n = 9
        a = make_arg(n, n)

        # test float and tensor variants
        for tol_value in [0.81, torch.tensor(0.81, device=device)]:
            # using rtol (relative tolerance) takes into account the largest singular value (1.5 in this case)
            result = torch.linalg.matrix_rank(a, rtol=tol_value)
            self.assertEqual(result, 2)  # there are 2 singular values above 1.5 * 0.81 = 1.215

            # atol is used directly to compare with singular values
            result = torch.linalg.matrix_rank(a, atol=tol_value)
            self.assertEqual(result, 7)  # there are 7 singular values above 0.81

            # when both are specified, the maximum of the two tolerances is used
            result = torch.linalg.matrix_rank(a, atol=tol_value, rtol=tol_value)
            self.assertEqual(result, 2)  # there are 2 singular values above max(0.81, 1.5 * 0.81)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipCUDAVersionIn([(11, 6), (11, 7)])  # https://github.com/pytorch/pytorch/issues/75391
    @dtypes(*floating_and_complex_types())
    def test_matrix_rank_empty(self, device, dtype):
        matrix_rank = torch.linalg.matrix_rank

        # NumPy doesn't work for input with no elements
        def run_test(shape0, shape1, batch):
            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
            rank_a = matrix_rank(a)
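            # matrices with a zero-sized row or column dimension have rank 0 for every batch element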
            expected = torch.zeros(batch, dtype=torch.int64, device=device)

            self.assertEqual(rank_a, matrix_rank(a.mH))

            aaH = torch.matmul(a, a.mH)
            rank_aaH = matrix_rank(aaH)
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
            self.assertEqual(rank_aaH, rank_aaH_hermitian)

            aHa = torch.matmul(a.mH, a)
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))

            self.assertEqual(rank_a, expected)
            self.assertEqual(matrix_rank(a, 0.01), expected)

            self.assertEqual(rank_aaH, expected)
            self.assertEqual(matrix_rank(aaH, 0.01), expected)

            self.assertEqual(rank_aaH_hermitian, expected)
            self.assertEqual(matrix_rank(aaH, 0.01, True), expected)

        batches = ((), (4, ), (3, 5, ))
        for batch in batches:
            run_test(0, 0, batch)
            run_test(0, 3, batch)
            run_test(3, 0, batch)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_matrix_rank_out_errors_and_warnings(self, device, dtype):
        # dtypes should be safely castable
        a = torch.eye(2, dtype=dtype, device=device)
        out = torch.empty(0, dtype=torch.bool, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Bool"):
            torch.linalg.matrix_rank(a, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, dtype=dtype, device=wrong_device)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.matrix_rank(a, out=out)

        # if out tensor with wrong shape is passed a warning is given
warning is given 3522*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 3523*da0073e9SAndroid Build Coastguard Worker out = torch.empty(3, dtype=dtype, device=device) 3524*da0073e9SAndroid Build Coastguard Worker # Trigger warning 3525*da0073e9SAndroid Build Coastguard Worker torch.linalg.matrix_rank(a, out=out) 3526*da0073e9SAndroid Build Coastguard Worker # Check warning occurs 3527*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 3528*da0073e9SAndroid Build Coastguard Worker self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 3529*da0073e9SAndroid Build Coastguard Worker 3530*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma 3531*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 3532*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 3533*da0073e9SAndroid Build Coastguard Worker def test_matrix_rank_basic(self, device, dtype): 3534*da0073e9SAndroid Build Coastguard Worker matrix_rank = torch.linalg.matrix_rank 3535*da0073e9SAndroid Build Coastguard Worker 3536*da0073e9SAndroid Build Coastguard Worker a = torch.eye(10, dtype=dtype, device=device) 3537*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_rank(a).item(), 10) 3538*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_rank(a, hermitian=True).item(), 10) 3539*da0073e9SAndroid Build Coastguard Worker 3540*da0073e9SAndroid Build Coastguard Worker a[5, 5] = 0 3541*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_rank(a).item(), 9) 3542*da0073e9SAndroid Build Coastguard Worker self.assertEqual(matrix_rank(a, hermitian=True).item(), 9) 3543*da0073e9SAndroid Build Coastguard Worker 3544*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3545*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 3546*da0073e9SAndroid Build Coastguard Worker # This tests only the cases where torch.chain_matmul differs from torch.linalg.multi_dot, for which it is an "alias".
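    # A minimal sketch of the difference (values are illustrative only):
    #   torch.chain_matmul(t)          # ok: a single 2D input is returned unchanged
    #   torch.linalg.multi_dot([t])    # error: multi_dot requires at least two tensors
    # chain_matmul also requires every operand to be 2D, while multi_dot allows the first
    # and last operands to be 1D.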
3547*da0073e9SAndroid Build Coastguard Worker def test_chain_matmul(self, device, dtype): 3548*da0073e9SAndroid Build Coastguard Worker # chain_matmul accepts a single input tensor while multi_dot does not 3549*da0073e9SAndroid Build Coastguard Worker t = make_tensor((2, 2), dtype=dtype, device=device) 3550*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, torch.chain_matmul(t)) 3551*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"): 3552*da0073e9SAndroid Build Coastguard Worker torch.chain_matmul() 3553*da0073e9SAndroid Build Coastguard Worker 3554*da0073e9SAndroid Build Coastguard Worker # chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to 3555*da0073e9SAndroid Build Coastguard Worker # be either 1D or 2D 3556*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"): 3557*da0073e9SAndroid Build Coastguard Worker torch.chain_matmul(make_tensor(1, dtype=dtype, device=device), make_tensor(1, dtype=dtype, device=device)) 3558*da0073e9SAndroid Build Coastguard Worker 3559*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3560*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 3561*da0073e9SAndroid Build Coastguard Worker def test_multi_dot(self, device, dtype): 3562*da0073e9SAndroid Build Coastguard Worker def check(*shapes): 3563*da0073e9SAndroid Build Coastguard Worker tensors = [make_tensor(shape, dtype=dtype, device=device) for shape in shapes] 3564*da0073e9SAndroid Build Coastguard Worker np_arrays = [tensor.cpu().numpy() for tensor in tensors] 3565*da0073e9SAndroid Build Coastguard Worker res = torch.linalg.multi_dot(tensors).cpu() 3566*da0073e9SAndroid Build Coastguard Worker ref = torch.from_numpy(np.array(np.linalg.multi_dot(np_arrays))) 3567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, ref) 3568*da0073e9SAndroid Build Coastguard Worker 3569*da0073e9SAndroid Build Coastguard Worker # test for inputs with empty dimensions 3570*da0073e9SAndroid Build Coastguard Worker check([0], [0]) 3571*da0073e9SAndroid Build Coastguard Worker check([2], [2, 0]) 3572*da0073e9SAndroid Build Coastguard Worker check([1, 0], [0]) 3573*da0073e9SAndroid Build Coastguard Worker check([0, 2], [2, 1]) 3574*da0073e9SAndroid Build Coastguard Worker check([2, 2], [2, 0]) 3575*da0073e9SAndroid Build Coastguard Worker check([2, 0], [0, 3]) 3576*da0073e9SAndroid Build Coastguard Worker check([0, 0], [0, 1]) 3577*da0073e9SAndroid Build Coastguard Worker check([4, 2], [2, 0], [0, 3], [3, 2]) 3578*da0073e9SAndroid Build Coastguard Worker 3579*da0073e9SAndroid Build Coastguard Worker # test variable output shapes 3580*da0073e9SAndroid Build Coastguard Worker check([2], [2]) 3581*da0073e9SAndroid Build Coastguard Worker check([1, 2], [2]) 3582*da0073e9SAndroid Build Coastguard Worker check([2], [2, 1]) 3583*da0073e9SAndroid Build Coastguard Worker check([1, 2], [2, 1]) 3584*da0073e9SAndroid Build Coastguard Worker check([3, 2], [2, 4]) 3585*da0073e9SAndroid Build Coastguard Worker 3586*da0073e9SAndroid Build Coastguard Worker # test multiple input tensors 3587*da0073e9SAndroid Build Coastguard Worker check([3], [3, 4], [4, 2], [2, 5], [5]) 3588*da0073e9SAndroid Build Coastguard Worker check([1, 2], [2, 2], [2, 3], [3, 1]) 3589*da0073e9SAndroid Build Coastguard Worker 3590*da0073e9SAndroid Build Coastguard Worker # test large tensors 
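        # (these longer chains also exercise multi_dot's chain-ordering logic, which is free to
        #  pick the cheapest parenthesization rather than multiplying strictly left to right)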
3591*da0073e9SAndroid Build Coastguard Worker check([10, 100], [100, 5], [5, 50]) 3592*da0073e9SAndroid Build Coastguard Worker check([10, 20], [20, 30], [30, 5]) 3593*da0073e9SAndroid Build Coastguard Worker 3594*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3595*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 3596*da0073e9SAndroid Build Coastguard Worker def test_multi_dot_errors(self, device, dtype): 3597*da0073e9SAndroid Build Coastguard Worker def check(tensors, out, msg): 3598*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 3599*da0073e9SAndroid Build Coastguard Worker torch.linalg.multi_dot(tensors, out=out) 3600*da0073e9SAndroid Build Coastguard Worker 3601*da0073e9SAndroid Build Coastguard Worker a = make_tensor(2, dtype=dtype, device=device) 3602*da0073e9SAndroid Build Coastguard Worker 3603*da0073e9SAndroid Build Coastguard Worker check([], None, "expected at least 2 tensors") 3604*da0073e9SAndroid Build Coastguard Worker check([a], None, "expected at least 2 tensors") 3605*da0073e9SAndroid Build Coastguard Worker 3606*da0073e9SAndroid Build Coastguard Worker check([torch.tensor(1, device=device, dtype=dtype), a], None, "the first tensor must be 1D or 2D") 3607*da0073e9SAndroid Build Coastguard Worker check([a, torch.tensor(1, device=device, dtype=dtype)], None, "the last tensor must be 1D or 2D") 3608*da0073e9SAndroid Build Coastguard Worker 3609*da0073e9SAndroid Build Coastguard Worker check([a, a, a], None, "tensor 1 must be 2D") 3610*da0073e9SAndroid Build Coastguard Worker check([a, make_tensor((2, 2, 2), dtype=dtype, device=device), a], None, "tensor 1 must be 2D") 3611*da0073e9SAndroid Build Coastguard Worker 3612*da0073e9SAndroid Build Coastguard Worker check([a, make_tensor(2, dtype=torch.double, device=device)], None, "all tensors must have be the same dtype") 3613*da0073e9SAndroid Build Coastguard Worker check([a, a], torch.empty(0, device=device, dtype=torch.double), "expected out tensor to have dtype") 3614*da0073e9SAndroid Build Coastguard Worker 3615*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cuda': 3616*da0073e9SAndroid Build Coastguard Worker check([a, make_tensor(2, dtype=dtype, device="cpu")], None, "all tensors must be on the same device") 3617*da0073e9SAndroid Build Coastguard Worker check([a, a], torch.empty(0, dtype=dtype), "expected out tensor to be on device") 3618*da0073e9SAndroid Build Coastguard Worker 3619*da0073e9SAndroid Build Coastguard Worker check([a, make_tensor(3, dtype=dtype, device=device)], None, "cannot be multiplied") 3620*da0073e9SAndroid Build Coastguard Worker check([a, make_tensor((3, 2), dtype=dtype, device=device), a], None, "cannot be multiplied") 3621*da0073e9SAndroid Build Coastguard Worker 3622*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6}) 3623*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 3624*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 3625*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 3626*da0073e9SAndroid Build Coastguard Worker def test_qr(self, device, dtype): 3627*da0073e9SAndroid Build Coastguard Worker def run_test(tensor_dims, some): 3628*da0073e9SAndroid Build Coastguard Worker A = torch.randn(*tensor_dims, dtype=dtype, device=device) 3629*da0073e9SAndroid Build Coastguard Worker Q, R = torch.qr(A, some=some) 3630*da0073e9SAndroid Build Coastguard Worker 3631*da0073e9SAndroid Build Coastguard Worker 
# Check0: Q[-2:] = (m, n_columns), R[-2:] = (n_columns, n) 3632*da0073e9SAndroid Build Coastguard Worker m, n = tensor_dims[-2:] 3633*da0073e9SAndroid Build Coastguard Worker n_columns = m if (not some) and m > n else min(m, n) 3634*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Q.size(-2), m) 3635*da0073e9SAndroid Build Coastguard Worker self.assertEqual(R.size(-1), n) 3636*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Q.size(-1), n_columns) 3637*da0073e9SAndroid Build Coastguard Worker 3638*da0073e9SAndroid Build Coastguard Worker A_ = A.cpu().numpy() 3639*da0073e9SAndroid Build Coastguard Worker Q_ = Q.cpu().numpy() 3640*da0073e9SAndroid Build Coastguard Worker R_ = R.cpu().numpy() 3641*da0073e9SAndroid Build Coastguard Worker 3642*da0073e9SAndroid Build Coastguard Worker # Check1: A = QR 3643*da0073e9SAndroid Build Coastguard Worker self.assertEqual(A_, np.matmul(Q_, R_)) 3644*da0073e9SAndroid Build Coastguard Worker 3645*da0073e9SAndroid Build Coastguard Worker # Check2: A = QR (with out) 3646*da0073e9SAndroid Build Coastguard Worker Q_out, R_out = torch.full_like(Q, math.nan), torch.full_like(R, math.nan) 3647*da0073e9SAndroid Build Coastguard Worker torch.qr(A, some=some, out=(Q_out, R_out)) 3648*da0073e9SAndroid Build Coastguard Worker Q_out_ = Q_out.cpu().numpy() 3649*da0073e9SAndroid Build Coastguard Worker R_out_ = R_out.cpu().numpy() 3650*da0073e9SAndroid Build Coastguard Worker self.assertEqual(A_, np.matmul(Q_out_, R_out_)) 3651*da0073e9SAndroid Build Coastguard Worker 3652*da0073e9SAndroid Build Coastguard Worker # Check3: Q == Q_out, R == R_out 3653*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Q_, Q_out_) 3654*da0073e9SAndroid Build Coastguard Worker self.assertEqual(R_, R_out_) 3655*da0073e9SAndroid Build Coastguard Worker 3656*da0073e9SAndroid Build Coastguard Worker # Check4: Q^{T}Q = I, triu(R) = R 3657*da0073e9SAndroid Build Coastguard Worker eye = torch.eye(n_columns, device=device, dtype=dtype).expand(Q.shape[:-2] + (n_columns, n_columns)).cpu().numpy() 3658*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np.matmul(Q_.swapaxes(-1, -2).conj(), Q_), eye) 3659*da0073e9SAndroid Build Coastguard Worker self.assertEqual(R.triu(), R) 3660*da0073e9SAndroid Build Coastguard Worker 3661*da0073e9SAndroid Build Coastguard Worker tensor_dims_list = [(0, 5), (0, 0), (5, 0), # Empty Tensors 3662*da0073e9SAndroid Build Coastguard Worker (2, 1, 0, 5), (2, 1, 0, 0), (2, 1, 5, 0), (2, 0, 5, 5), # Batched empty Tensors 3663*da0073e9SAndroid Build Coastguard Worker (3, 5), (5, 5), (5, 3), # Single matrix 3664*da0073e9SAndroid Build Coastguard Worker (7, 3, 5), (7, 5, 5), (7, 5, 3), # 3-dim Tensors 3665*da0073e9SAndroid Build Coastguard Worker (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)] # 4-dim Tensors 3666*da0073e9SAndroid Build Coastguard Worker for tensor_dims, some in itertools.product(tensor_dims_list, [True, False]): 3667*da0073e9SAndroid Build Coastguard Worker run_test(tensor_dims, some) 3668*da0073e9SAndroid Build Coastguard Worker 3669*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 3670*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 3671*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 3672*da0073e9SAndroid Build Coastguard Worker def test_qr_vs_numpy(self, device, dtype): 3673*da0073e9SAndroid Build Coastguard Worker """ 3674*da0073e9SAndroid Build Coastguard Worker test torch.linalg.qr vs numpy.linalg.qr 3675*da0073e9SAndroid Build Coastguard Worker 
""" 3676*da0073e9SAndroid Build Coastguard Worker sizes_to_test = [ 3677*da0073e9SAndroid Build Coastguard Worker (7, 5), 3678*da0073e9SAndroid Build Coastguard Worker (5, 7), 3679*da0073e9SAndroid Build Coastguard Worker (5, 0), # empty 3680*da0073e9SAndroid Build Coastguard Worker (0, 5), # empty 3681*da0073e9SAndroid Build Coastguard Worker ] 3682*da0073e9SAndroid Build Coastguard Worker for size in sizes_to_test: 3683*da0073e9SAndroid Build Coastguard Worker t = torch.randn(size, device=device, dtype=dtype) 3684*da0073e9SAndroid Build Coastguard Worker np_t = t.cpu().numpy() 3685*da0073e9SAndroid Build Coastguard Worker for mode in ['reduced', 'complete']: 3686*da0073e9SAndroid Build Coastguard Worker exp_q, exp_r = np.linalg.qr(np_t, mode=mode) 3687*da0073e9SAndroid Build Coastguard Worker q, r = torch.linalg.qr(t, mode=mode) 3688*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q, exp_q) 3689*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, exp_r) 3690*da0073e9SAndroid Build Coastguard Worker # 3691*da0073e9SAndroid Build Coastguard Worker # for mode='r' we need a special logic because numpy returns only r 3692*da0073e9SAndroid Build Coastguard Worker exp_r = np.linalg.qr(np_t, mode='r') 3693*da0073e9SAndroid Build Coastguard Worker q, r = torch.linalg.qr(t, mode='r') 3694*da0073e9SAndroid Build Coastguard Worker # check that q is empty 3695*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q.shape, (0,)) 3696*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q.dtype, t.dtype) 3697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q.device, t.device) 3698*da0073e9SAndroid Build Coastguard Worker # check r 3699*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, exp_r) 3700*da0073e9SAndroid Build Coastguard Worker 3701*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 3702*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 3703*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 3704*da0073e9SAndroid Build Coastguard Worker def test_linalg_qr_autograd_errors(self, device, dtype): 3705*da0073e9SAndroid Build Coastguard Worker # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but 3706*da0073e9SAndroid Build Coastguard Worker # without 'q' you cannot compute the backward pass. Check that 3707*da0073e9SAndroid Build Coastguard Worker # linalg_qr_backward complains cleanly in that case. 
3708*da0073e9SAndroid Build Coastguard Worker inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True) 3709*da0073e9SAndroid Build Coastguard Worker q, r = torch.linalg.qr(inp, mode='r') 3710*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q.shape, (0,)) # empty tensor 3711*da0073e9SAndroid Build Coastguard Worker b = torch.sum(r) 3712*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3713*da0073e9SAndroid Build Coastguard Worker "The derivative of linalg.qr depends on Q"): 3714*da0073e9SAndroid Build Coastguard Worker b.backward() 3715*da0073e9SAndroid Build Coastguard Worker inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True) 3716*da0073e9SAndroid Build Coastguard Worker q, r = torch.linalg.qr(inp, mode='complete') 3717*da0073e9SAndroid Build Coastguard Worker b = torch.sum(r) 3718*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3719*da0073e9SAndroid Build Coastguard Worker "The QR decomposition is not differentiable when mode='complete' and nrows > ncols"): 3720*da0073e9SAndroid Build Coastguard Worker b.backward() 3721*da0073e9SAndroid Build Coastguard Worker 3722*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 3723*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 3724*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 3725*da0073e9SAndroid Build Coastguard Worker def test_qr_batched(self, device, dtype): 3726*da0073e9SAndroid Build Coastguard Worker """ 3727*da0073e9SAndroid Build Coastguard Worker test torch.linalg.qr vs numpy.linalg.qr. We need some special logic 3728*da0073e9SAndroid Build Coastguard Worker because numpy does not support batched qr 3729*da0073e9SAndroid Build Coastguard Worker """ 3730*da0073e9SAndroid Build Coastguard Worker def np_qr_batched(a, mode): 3731*da0073e9SAndroid Build Coastguard Worker """poor man's batched version of np.linalg.qr""" 3732*da0073e9SAndroid Build Coastguard Worker all_q = [] 3733*da0073e9SAndroid Build Coastguard Worker all_r = [] 3734*da0073e9SAndroid Build Coastguard Worker for matrix in a: 3735*da0073e9SAndroid Build Coastguard Worker result = np.linalg.qr(matrix, mode=mode) 3736*da0073e9SAndroid Build Coastguard Worker if mode == 'r': 3737*da0073e9SAndroid Build Coastguard Worker all_r.append(result) 3738*da0073e9SAndroid Build Coastguard Worker else: 3739*da0073e9SAndroid Build Coastguard Worker q, r = result 3740*da0073e9SAndroid Build Coastguard Worker all_q.append(q) 3741*da0073e9SAndroid Build Coastguard Worker all_r.append(r) 3742*da0073e9SAndroid Build Coastguard Worker if mode == 'r': 3743*da0073e9SAndroid Build Coastguard Worker return np.array(all_r) 3744*da0073e9SAndroid Build Coastguard Worker else: 3745*da0073e9SAndroid Build Coastguard Worker return np.array(all_q), np.array(all_r) 3746*da0073e9SAndroid Build Coastguard Worker 3747*da0073e9SAndroid Build Coastguard Worker t = torch.randn((3, 7, 5), device=device, dtype=dtype) 3748*da0073e9SAndroid Build Coastguard Worker np_t = t.cpu().numpy() 3749*da0073e9SAndroid Build Coastguard Worker for mode in ['reduced', 'complete']: 3750*da0073e9SAndroid Build Coastguard Worker exp_q, exp_r = np_qr_batched(np_t, mode=mode) 3751*da0073e9SAndroid Build Coastguard Worker q, r = torch.linalg.qr(t, mode=mode) 3752*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q, exp_q) 3753*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, exp_r) 3754*da0073e9SAndroid Build
Coastguard Worker # for mode='r' we need a special logic because numpy returns only r 3755*da0073e9SAndroid Build Coastguard Worker exp_r = np_qr_batched(np_t, mode='r') 3756*da0073e9SAndroid Build Coastguard Worker q, r = torch.linalg.qr(t, mode='r') 3757*da0073e9SAndroid Build Coastguard Worker # check that q is empty 3758*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q.shape, (0,)) 3759*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q.dtype, t.dtype) 3760*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q.device, t.device) 3761*da0073e9SAndroid Build Coastguard Worker # check r 3762*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, exp_r) 3763*da0073e9SAndroid Build Coastguard Worker 3764*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 3765*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 3766*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 3767*da0073e9SAndroid Build Coastguard Worker def test_qr_error_cases(self, device, dtype): 3768*da0073e9SAndroid Build Coastguard Worker t1 = torch.randn(5, device=device, dtype=dtype) 3769*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'linalg.qr: The input tensor A must have at least 2 dimensions.'): 3770*da0073e9SAndroid Build Coastguard Worker torch.linalg.qr(t1) 3771*da0073e9SAndroid Build Coastguard Worker t2 = torch.randn((5, 7), device=device, dtype=dtype) 3772*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"): 3773*da0073e9SAndroid Build Coastguard Worker torch.linalg.qr(t2, mode='hello') 3774*da0073e9SAndroid Build Coastguard Worker 3775*da0073e9SAndroid Build Coastguard Worker def _check_einsum(self, *args, np_args=None): 3776*da0073e9SAndroid Build Coastguard Worker if np_args is None: 3777*da0073e9SAndroid Build Coastguard Worker np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args] 3778*da0073e9SAndroid Build Coastguard Worker ref = np.einsum(*np_args) 3779*da0073e9SAndroid Build Coastguard Worker res = torch.einsum(*args) 3780*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3781*da0073e9SAndroid Build Coastguard Worker 3782*da0073e9SAndroid Build Coastguard Worker # Check that the other variations for opt_einsum work too 3783*da0073e9SAndroid Build Coastguard Worker if TEST_OPT_EINSUM: 3784*da0073e9SAndroid Build Coastguard Worker with opt_einsum.flags(enabled=False): 3785*da0073e9SAndroid Build Coastguard Worker res = torch.einsum(*args) 3786*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3787*da0073e9SAndroid Build Coastguard Worker 3788*da0073e9SAndroid Build Coastguard Worker with opt_einsum.flags(enabled=True, strategy='greedy'): 3789*da0073e9SAndroid Build Coastguard Worker res = torch.einsum(*args) 3790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3791*da0073e9SAndroid Build Coastguard Worker 3792*da0073e9SAndroid Build Coastguard Worker with opt_einsum.flags(enabled=True, strategy='optimal'): 3793*da0073e9SAndroid Build Coastguard Worker res = torch.einsum(*args) 3794*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3795*da0073e9SAndroid Build Coastguard Worker 3796*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 3797*da0073e9SAndroid Build Coastguard Worker def test_einsum(self, device, dtype): 3798*da0073e9SAndroid Build Coastguard Worker # Test cases from 
https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f 3799*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) 3800*da0073e9SAndroid Build Coastguard Worker y = make_tensor((7,), dtype=dtype, device=device) 3801*da0073e9SAndroid Build Coastguard Worker A = make_tensor((3, 5), dtype=dtype, device=device) 3802*da0073e9SAndroid Build Coastguard Worker B = make_tensor((2, 5), dtype=dtype, device=device) 3803*da0073e9SAndroid Build Coastguard Worker C = make_tensor((2, 3, 5), dtype=dtype, device=device) 3804*da0073e9SAndroid Build Coastguard Worker D = make_tensor((2, 5, 7), dtype=dtype, device=device) 3805*da0073e9SAndroid Build Coastguard Worker E = make_tensor((7, 9), dtype=dtype, device=device) 3806*da0073e9SAndroid Build Coastguard Worker F = make_tensor((2, 3, 3, 5), dtype=dtype, device=device) 3807*da0073e9SAndroid Build Coastguard Worker G = make_tensor((5, 4, 6), dtype=dtype, device=device) 3808*da0073e9SAndroid Build Coastguard Worker H = make_tensor((4, 4), dtype=dtype, device=device) 3809*da0073e9SAndroid Build Coastguard Worker I = make_tensor((2, 3, 2), dtype=dtype, device=device) 3810*da0073e9SAndroid Build Coastguard Worker 3811*da0073e9SAndroid Build Coastguard Worker # Vector operations 3812*da0073e9SAndroid Build Coastguard Worker self._check_einsum('i->', x) # sum 3813*da0073e9SAndroid Build Coastguard Worker self._check_einsum('i,i->', x, x) # dot 3814*da0073e9SAndroid Build Coastguard Worker self._check_einsum('i,i->i', x, x) # vector element-wise mul 3815*da0073e9SAndroid Build Coastguard Worker self._check_einsum('i,j->ij', x, y) # outer 3816*da0073e9SAndroid Build Coastguard Worker 3817*da0073e9SAndroid Build Coastguard Worker # Matrix operations 3818*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ij->ji", A) # transpose 3819*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ij->j", A) # row sum 3820*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ij->i", A) # col sum 3821*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ij,ij->ij", A, A) # matrix element-wise mul 3822*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ij,j->i", A, x) # matrix vector multiplication 3823*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ij,kj->ik", A, B) # matmul 3824*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ij,ab->ijab", A, E) # matrix outer product 3825*da0073e9SAndroid Build Coastguard Worker 3826*da0073e9SAndroid Build Coastguard Worker # Tensor operations 3827*da0073e9SAndroid Build Coastguard Worker self._check_einsum("Aij,Ajk->Aik", C, D) # batch matmul 3828*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ijk,jk->i", C, A) # tensor matrix contraction 3829*da0073e9SAndroid Build Coastguard Worker self._check_einsum("aij,jk->aik", D, E) # tensor matrix contraction 3830*da0073e9SAndroid Build Coastguard Worker self._check_einsum("abCd,dFg->abCFg", F, G) # tensor tensor contraction 3831*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ijk,jk->ik", C, A) # tensor matrix contraction with double indices 3832*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ijk,jk->ij", C, A) # tensor matrix contraction with double indices 3833*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ijk,ik->j", C, B) # non contiguous 3834*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ijk,ik->jk", C, B) # non contiguous with double indices 3835*da0073e9SAndroid Build Coastguard Worker 3836*da0073e9SAndroid Build
Coastguard Worker # Test diagonals 3837*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ii", H) # trace 3838*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ii->i", H) # diagonal 3839*da0073e9SAndroid Build Coastguard Worker self._check_einsum('iji->j', I) # non-contiguous trace 3840*da0073e9SAndroid Build Coastguard Worker self._check_einsum('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device)) 3841*da0073e9SAndroid Build Coastguard Worker 3842*da0073e9SAndroid Build Coastguard Worker # Test ellipsis 3843*da0073e9SAndroid Build Coastguard Worker self._check_einsum("i...->...", H) 3844*da0073e9SAndroid Build Coastguard Worker self._check_einsum("ki,...k->i...", A.t(), B) 3845*da0073e9SAndroid Build Coastguard Worker self._check_einsum("k...,jk->...", A.t(), B) 3846*da0073e9SAndroid Build Coastguard Worker self._check_einsum('...ik, ...j -> ...ij', C, x) 3847*da0073e9SAndroid Build Coastguard Worker self._check_einsum('Bik,k...j->i...j', C, make_tensor((5, 3), dtype=dtype, device=device)) 3848*da0073e9SAndroid Build Coastguard Worker self._check_einsum('i...j, ij... -> ...ij', C, make_tensor((2, 5, 2, 3), dtype=dtype, device=device)) 3849*da0073e9SAndroid Build Coastguard Worker 3850*da0073e9SAndroid Build Coastguard Worker # torch.bilinear with noncontiguous tensors 3851*da0073e9SAndroid Build Coastguard Worker l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True) 3852*da0073e9SAndroid Build Coastguard Worker r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True) 3853*da0073e9SAndroid Build Coastguard Worker w = make_tensor((15, 10, 20), dtype=dtype, device=device) 3854*da0073e9SAndroid Build Coastguard Worker self._check_einsum("bn,anm,bm->ba", l, w, r) 3855*da0073e9SAndroid Build Coastguard Worker 3856*da0073e9SAndroid Build Coastguard Worker # with strided tensors 3857*da0073e9SAndroid Build Coastguard Worker self._check_einsum("bn,Anm,bm->bA", l[:, ::2], w[:, ::2, ::2], r[:, ::2]) 3858*da0073e9SAndroid Build Coastguard Worker 3859*da0073e9SAndroid Build Coastguard Worker # test multiple inputs 3860*da0073e9SAndroid Build Coastguard Worker self._check_einsum("...,be,b...,beg,gi,bc...->bi...", A, B, C, D, E, F) 3861*da0073e9SAndroid Build Coastguard Worker 3862*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 3863*da0073e9SAndroid Build Coastguard Worker def test_einsum_sublist_format(self, device, dtype): 3864*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) 3865*da0073e9SAndroid Build Coastguard Worker y = make_tensor((7,), dtype=dtype, device=device) 3866*da0073e9SAndroid Build Coastguard Worker A = make_tensor((3, 5), dtype=dtype, device=device) 3867*da0073e9SAndroid Build Coastguard Worker B = make_tensor((2, 5), dtype=dtype, device=device) 3868*da0073e9SAndroid Build Coastguard Worker C = make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device) 3869*da0073e9SAndroid Build Coastguard Worker 3870*da0073e9SAndroid Build Coastguard Worker self._check_einsum(x, [0]) 3871*da0073e9SAndroid Build Coastguard Worker self._check_einsum(x, [0], []) 3872*da0073e9SAndroid Build Coastguard Worker self._check_einsum(x, [0], y, [1], [0, 1]) 3873*da0073e9SAndroid Build Coastguard Worker self._check_einsum(A, [0, 1], [1, 0]) 3874*da0073e9SAndroid Build Coastguard Worker self._check_einsum(A, [0, 1], x, [1], [0]) 3875*da0073e9SAndroid Build Coastguard Worker self._check_einsum(A, [0, 1], B, [2, 1]) 3876*da0073e9SAndroid Build Coastguard Worker 
self._check_einsum(A, [0, 1], B, [2, 1], [0, 2]) 3877*da0073e9SAndroid Build Coastguard Worker self._check_einsum(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis]) 3878*da0073e9SAndroid Build Coastguard Worker self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0]) 3879*da0073e9SAndroid Build Coastguard Worker self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis]) 3880*da0073e9SAndroid Build Coastguard Worker self._check_einsum(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis]) 3881*da0073e9SAndroid Build Coastguard Worker 3882*da0073e9SAndroid Build Coastguard Worker # torch.bilinear with noncontiguous tensors 3883*da0073e9SAndroid Build Coastguard Worker l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True) 3884*da0073e9SAndroid Build Coastguard Worker r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True) 3885*da0073e9SAndroid Build Coastguard Worker w = make_tensor((15, 10, 20), dtype=dtype, device=device) 3886*da0073e9SAndroid Build Coastguard Worker self._check_einsum(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2]) 3887*da0073e9SAndroid Build Coastguard Worker 3888*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 3889*da0073e9SAndroid Build Coastguard Worker def test_einsum_random(self, device, dtype): 3890*da0073e9SAndroid Build Coastguard Worker def convert_label(label): 3891*da0073e9SAndroid Build Coastguard Worker if label == ...: 3892*da0073e9SAndroid Build Coastguard Worker return '...' 3893*da0073e9SAndroid Build Coastguard Worker elif label < 26: 3894*da0073e9SAndroid Build Coastguard Worker return chr(ord('A') + label) 3895*da0073e9SAndroid Build Coastguard Worker else: 3896*da0073e9SAndroid Build Coastguard Worker return chr(ord('a') + label - 26) 3897*da0073e9SAndroid Build Coastguard Worker 3898*da0073e9SAndroid Build Coastguard Worker def convert_sublist(sublist): 3899*da0073e9SAndroid Build Coastguard Worker return ''.join(convert_label(label) for label in sublist) 3900*da0073e9SAndroid Build Coastguard Worker 3901*da0073e9SAndroid Build Coastguard Worker def test(n=10, # how many tests to generate 3902*da0073e9SAndroid Build Coastguard Worker n_labels=5, # how many labels available 3903*da0073e9SAndroid Build Coastguard Worker min_ops=1, max_ops=4, # min and max number of operands per test 3904*da0073e9SAndroid Build Coastguard Worker min_dims=1, max_dims=3, # min and max number of dimensions per operand 3905*da0073e9SAndroid Build Coastguard Worker min_size=1, max_size=8, # min and max size of each dimension 3906*da0073e9SAndroid Build Coastguard Worker max_out_dim=3, # max number of dimensions for the output 3907*da0073e9SAndroid Build Coastguard Worker enable_diagonals=True, # controls if labels can be repeated for diagonals 3908*da0073e9SAndroid Build Coastguard Worker ellipsis_prob=0.5, # probability of including ellipsis in operand 3909*da0073e9SAndroid Build Coastguard Worker broadcasting_prob=0.1): # probability of turning some dim sizes 1 for broadcasting 3910*da0073e9SAndroid Build Coastguard Worker 3911*da0073e9SAndroid Build Coastguard Worker all_labels = torch.arange(52) 3912*da0073e9SAndroid Build Coastguard Worker 3913*da0073e9SAndroid Build Coastguard Worker assert 0 <= n 3914*da0073e9SAndroid Build Coastguard Worker assert 0 <= n_labels < len(all_labels) 3915*da0073e9SAndroid Build Coastguard Worker assert 0 < min_ops <= max_ops 3916*da0073e9SAndroid Build Coastguard Worker assert 0 <= min_dims <= max_dims 3917*da0073e9SAndroid Build Coastguard Worker assert 0 <= min_size <= max_size 
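            # (the asserts above and below only sanity-check the test-generator parameters
            #  before any random cases are built)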
3918*da0073e9SAndroid Build Coastguard Worker assert 0 <= max_out_dim 3919*da0073e9SAndroid Build Coastguard Worker assert enable_diagonals or max_dims <= n_labels 3920*da0073e9SAndroid Build Coastguard Worker 3921*da0073e9SAndroid Build Coastguard Worker for _ in range(n): 3922*da0073e9SAndroid Build Coastguard Worker 3923*da0073e9SAndroid Build Coastguard Worker # Select a subset of labels for this test and give them random sizes 3924*da0073e9SAndroid Build Coastguard Worker possible_labels = all_labels[torch.randperm(len(all_labels))[:n_labels]] 3925*da0073e9SAndroid Build Coastguard Worker labels_size = torch.randint_like(all_labels, min_size, max_size + 1) 3926*da0073e9SAndroid Build Coastguard Worker ellipsis_shape = torch.randint(min_size, max_size + 1, (max_dims - min_dims,)) 3927*da0073e9SAndroid Build Coastguard Worker 3928*da0073e9SAndroid Build Coastguard Worker operands = [] 3929*da0073e9SAndroid Build Coastguard Worker sublists = [] 3930*da0073e9SAndroid Build Coastguard Worker 3931*da0073e9SAndroid Build Coastguard Worker ell_size = 0 3932*da0073e9SAndroid Build Coastguard Worker valid_labels = set() 3933*da0073e9SAndroid Build Coastguard Worker 3934*da0073e9SAndroid Build Coastguard Worker # create random input operands 3935*da0073e9SAndroid Build Coastguard Worker for _ in range(random.randint(min_ops, max_ops)): 3936*da0073e9SAndroid Build Coastguard Worker n_dim = random.randint(min_dims, max_dims) 3937*da0073e9SAndroid Build Coastguard Worker labels_idx = torch.ones(len(possible_labels)).multinomial(n_dim, enable_diagonals) 3938*da0073e9SAndroid Build Coastguard Worker labels = possible_labels[labels_idx] 3939*da0073e9SAndroid Build Coastguard Worker valid_labels.update(labels.tolist()) 3940*da0073e9SAndroid Build Coastguard Worker shape = labels_size[labels] 3941*da0073e9SAndroid Build Coastguard Worker 3942*da0073e9SAndroid Build Coastguard Worker # turn some dimensions to size 1 for testing broadcasting 3943*da0073e9SAndroid Build Coastguard Worker mask = Binomial(probs=broadcasting_prob).sample((n_dim,)) 3944*da0073e9SAndroid Build Coastguard Worker broadcast_labels = torch.unique(labels[mask == 1]) 3945*da0073e9SAndroid Build Coastguard Worker shape[(labels[..., None] == broadcast_labels).any(-1)] = 1 3946*da0073e9SAndroid Build Coastguard Worker 3947*da0073e9SAndroid Build Coastguard Worker labels = labels.tolist() 3948*da0073e9SAndroid Build Coastguard Worker shape = shape.tolist() 3949*da0073e9SAndroid Build Coastguard Worker 3950*da0073e9SAndroid Build Coastguard Worker # include ellipsis if not all dimensions were assigned a label already 3951*da0073e9SAndroid Build Coastguard Worker if n_dim < max_dims and torch.rand(1) < ellipsis_prob: 3952*da0073e9SAndroid Build Coastguard Worker ell_num_dim = random.randint(1, max_dims - n_dim) 3953*da0073e9SAndroid Build Coastguard Worker ell_size = max(ell_size, ell_num_dim) 3954*da0073e9SAndroid Build Coastguard Worker ell_shape = ellipsis_shape[-ell_num_dim:] 3955*da0073e9SAndroid Build Coastguard Worker # again, turn some dimensions to size 1 for broadcasting 3956*da0073e9SAndroid Build Coastguard Worker mask = Binomial(probs=broadcasting_prob).sample((ell_num_dim,)) 3957*da0073e9SAndroid Build Coastguard Worker ell_shape[mask == 1] = 1 3958*da0073e9SAndroid Build Coastguard Worker ell_index = random.randint(0, n_dim) 3959*da0073e9SAndroid Build Coastguard Worker shape[ell_index:ell_index] = ell_shape 3960*da0073e9SAndroid Build Coastguard Worker labels.insert(ell_index, ...) 
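                # materialize the operand with the shape chosen above and record its sublist,
                # so the same case can be replayed below in both equation and sublist form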
3961*da0073e9SAndroid Build Coastguard Worker 3962*da0073e9SAndroid Build Coastguard Worker operands.append(make_tensor(shape, dtype=dtype, device=device)) 3963*da0073e9SAndroid Build Coastguard Worker sublists.append(labels) 3964*da0073e9SAndroid Build Coastguard Worker 3965*da0073e9SAndroid Build Coastguard Worker # NumPy has a bug with the sublist format so for now we compare PyTorch sublist 3966*da0073e9SAndroid Build Coastguard Worker # implementation against the equation format implementation of NumPy 3967*da0073e9SAndroid Build Coastguard Worker # see https://github.com/numpy/numpy/issues/10926 3968*da0073e9SAndroid Build Coastguard Worker np_operands = [op.cpu().numpy() for op in operands] 3969*da0073e9SAndroid Build Coastguard Worker 3970*da0073e9SAndroid Build Coastguard Worker # test equation format 3971*da0073e9SAndroid Build Coastguard Worker equation = ','.join(convert_sublist(l) for l in sublists) 3972*da0073e9SAndroid Build Coastguard Worker self._check_einsum(equation, *operands, np_args=(equation, *np_operands)) 3973*da0073e9SAndroid Build Coastguard Worker 3974*da0073e9SAndroid Build Coastguard Worker # test sublist format 3975*da0073e9SAndroid Build Coastguard Worker args = list(itertools.chain.from_iterable(zip(operands, sublists))) 3976*da0073e9SAndroid Build Coastguard Worker self._check_einsum(*args, np_args=(equation, *np_operands)) 3977*da0073e9SAndroid Build Coastguard Worker 3978*da0073e9SAndroid Build Coastguard Worker # generate an explicit output 3979*da0073e9SAndroid Build Coastguard Worker out_sublist = [] 3980*da0073e9SAndroid Build Coastguard Worker num_out_labels = max(0, random.randint(0, min(max_out_dim, len(valid_labels))) - ell_size) 3981*da0073e9SAndroid Build Coastguard Worker if num_out_labels > 0: 3982*da0073e9SAndroid Build Coastguard Worker out_labels_idx = torch.ones(len(valid_labels)).multinomial(num_out_labels) 3983*da0073e9SAndroid Build Coastguard Worker out_sublist = torch.tensor(list(valid_labels))[out_labels_idx].tolist() 3984*da0073e9SAndroid Build Coastguard Worker out_sublist.insert(random.randint(0, num_out_labels), ...) 
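            # an ellipsis is also inserted into the explicit output sublist so that any
            # dimensions covered by '...' in the inputs are kept in the result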
3985*da0073e9SAndroid Build Coastguard Worker 3986*da0073e9SAndroid Build Coastguard Worker # test equation format with explicit output 3987*da0073e9SAndroid Build Coastguard Worker equation += '->' + convert_sublist(out_sublist) 3988*da0073e9SAndroid Build Coastguard Worker self._check_einsum(equation, *operands, np_args=(equation, *np_operands)) 3989*da0073e9SAndroid Build Coastguard Worker 3990*da0073e9SAndroid Build Coastguard Worker # test sublist format with explicit output 3991*da0073e9SAndroid Build Coastguard Worker args.append(out_sublist) 3992*da0073e9SAndroid Build Coastguard Worker self._check_einsum(*args, np_args=(equation, *np_operands)) 3993*da0073e9SAndroid Build Coastguard Worker 3994*da0073e9SAndroid Build Coastguard Worker test(500) 3995*da0073e9SAndroid Build Coastguard Worker 3996*da0073e9SAndroid Build Coastguard Worker def test_einsum_corner_cases(self, device): 3997*da0073e9SAndroid Build Coastguard Worker def check(equation, *operands, expected_output): 3998*da0073e9SAndroid Build Coastguard Worker tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple) 3999*da0073e9SAndroid Build Coastguard Worker else make_tensor(operand, dtype=torch.float32, device=device) for operand in operands] 4000*da0073e9SAndroid Build Coastguard Worker output = torch.einsum(equation, tensors) 4001*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device)) 4002*da0073e9SAndroid Build Coastguard Worker 4003*da0073e9SAndroid Build Coastguard Worker # Test equation variations 4004*da0073e9SAndroid Build Coastguard Worker check(' ', 1, expected_output=1) 4005*da0073e9SAndroid Build Coastguard Worker check(' -> ', 1, expected_output=1) 4006*da0073e9SAndroid Build Coastguard Worker check(' , ', 2, 2, expected_output=4) 4007*da0073e9SAndroid Build Coastguard Worker check(' , , ', 2, 2, 2, expected_output=8) 4008*da0073e9SAndroid Build Coastguard Worker check(' , -> ', 2, 2, expected_output=4) 4009*da0073e9SAndroid Build Coastguard Worker check(' i ', [1], expected_output=[1]) 4010*da0073e9SAndroid Build Coastguard Worker check(' i -> ', [1], expected_output=1) 4011*da0073e9SAndroid Build Coastguard Worker check(' i -> i ', [1], expected_output=[1]) 4012*da0073e9SAndroid Build Coastguard Worker check(' i , i ', [2], [2], expected_output=4) 4013*da0073e9SAndroid Build Coastguard Worker check(' i , i -> i ', [2], [2], expected_output=[4]) 4014*da0073e9SAndroid Build Coastguard Worker 4015*da0073e9SAndroid Build Coastguard Worker # Test tensors with 0 size dimensions 4016*da0073e9SAndroid Build Coastguard Worker check('i', [], expected_output=[]) 4017*da0073e9SAndroid Build Coastguard Worker check(' i j -> j', [[], []], expected_output=[]) 4018*da0073e9SAndroid Build Coastguard Worker check('ij->i', [[], []], expected_output=[0., 0.]) 4019*da0073e9SAndroid Build Coastguard Worker check(' i j k , k -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []]) 4020*da0073e9SAndroid Build Coastguard Worker 4021*da0073e9SAndroid Build Coastguard Worker # Test broadcasting 4022*da0073e9SAndroid Build Coastguard Worker check('i,j', [2], [1, 2], expected_output=[[2, 4]]) 4023*da0073e9SAndroid Build Coastguard Worker check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]]) 4024*da0073e9SAndroid Build Coastguard Worker 4025*da0073e9SAndroid Build Coastguard Worker # Test ellipsis broadcasting 4026*da0073e9SAndroid Build Coastguard Worker check('...', 1,
expected_output=1) 4027*da0073e9SAndroid Build Coastguard Worker check('...->', 1, expected_output=1) 4028*da0073e9SAndroid Build Coastguard Worker check('...->...', 1, expected_output=1) 4029*da0073e9SAndroid Build Coastguard Worker check('...', [1], expected_output=[1]) 4030*da0073e9SAndroid Build Coastguard Worker check('...->', [1], expected_output=1) 4031*da0073e9SAndroid Build Coastguard Worker check('z...->z', [1], expected_output=[1]) 4032*da0073e9SAndroid Build Coastguard Worker check('Z...->...Z', [1], expected_output=[1]) 4033*da0073e9SAndroid Build Coastguard Worker check('...a->', [[2], [4]], expected_output=6) 4034*da0073e9SAndroid Build Coastguard Worker check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]]) 4035*da0073e9SAndroid Build Coastguard Worker 4036*da0073e9SAndroid Build Coastguard Worker def test_einsum_error_cases(self, device): 4037*da0073e9SAndroid Build Coastguard Worker def check(*args, regex, exception=RuntimeError): 4038*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(exception, r'einsum\(\):.*' + regex): 4039*da0073e9SAndroid Build Coastguard Worker torch.einsum(*args) 4040*da0073e9SAndroid Build Coastguard Worker 4041*da0073e9SAndroid Build Coastguard Worker x = make_tensor((2,), dtype=torch.float32, device=device) 4042*da0073e9SAndroid Build Coastguard Worker y = make_tensor((2, 3), dtype=torch.float32, device=device) 4043*da0073e9SAndroid Build Coastguard Worker 4044*da0073e9SAndroid Build Coastguard Worker check('', [], regex=r'at least one operand', exception=ValueError) 4045*da0073e9SAndroid Build Coastguard Worker check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis') 4046*da0073e9SAndroid Build Coastguard Worker check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found') 4047*da0073e9SAndroid Build Coastguard Worker check('1', [x], regex=r'invalid subscript given at index 0') 4048*da0073e9SAndroid Build Coastguard Worker check(',', [x], regex=r'fewer operands were provided than specified in the equation') 4049*da0073e9SAndroid Build Coastguard Worker check('', [x, x], regex=r'more operands were provided than specified in the equation') 4050*da0073e9SAndroid Build Coastguard Worker check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number ' 4051*da0073e9SAndroid Build Coastguard Worker r'of dimensions \(1\) for operand 0 and no ellipsis was given') 4052*da0073e9SAndroid Build Coastguard Worker check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number ' 4053*da0073e9SAndroid Build Coastguard Worker r'of dimensions \(1\) for operand 0 and no ellipsis was given') 4054*da0073e9SAndroid Build Coastguard Worker check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number ' 4055*da0073e9SAndroid Build Coastguard Worker r'of dimensions \(1\) for operand 0') 4056*da0073e9SAndroid Build Coastguard Worker check('a->... 
.', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found') 4057*da0073e9SAndroid Build Coastguard Worker check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)') 4058*da0073e9SAndroid Build Coastguard Worker check('a->1', [x], regex=r'invalid subscript given at index 3') 4059*da0073e9SAndroid Build Coastguard Worker check('a->aa', [x], regex=r'output subscript a appears more than once in the output') 4060*da0073e9SAndroid Build Coastguard Worker check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand') 4061*da0073e9SAndroid Build Coastguard Worker check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2') 4062*da0073e9SAndroid Build Coastguard Worker check('...,...', [x, y], regex=r'does not broadcast') 4063*da0073e9SAndroid Build Coastguard Worker check('a,a', [x, make_tensor((3,), dtype=torch.float32, device=device)], regex=r'does not broadcast') 4064*da0073e9SAndroid Build Coastguard Worker check('a, ba', [x, y], regex=r'subscript a has size 3 for operand 1 which does not broadcast with previously' 4065*da0073e9SAndroid Build Coastguard Worker r' seen size 2') 4066*da0073e9SAndroid Build Coastguard Worker 4067*da0073e9SAndroid Build Coastguard Worker check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError) 4068*da0073e9SAndroid Build Coastguard Worker check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError) 4069*da0073e9SAndroid Build Coastguard Worker 4070*da0073e9SAndroid Build Coastguard Worker def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_conditioned=False): 4071*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, dtype=dtype, device=device) 4072*da0073e9SAndroid Build Coastguard Worker make_fullrank = partial(make_fullrank_matrices_with_distinct_singular_values, dtype=dtype, device=device) 4073*da0073e9SAndroid Build Coastguard Worker b, n, k = shape 4074*da0073e9SAndroid Build Coastguard Worker for left, uni, expand_a, tr_a, conj_a, expand_b, tr_b, conj_b in product((True, False), repeat=8): 4075*da0073e9SAndroid Build Coastguard Worker # expand means that we generate a batch of matrices with a stride of zero in the batch dimension 4076*da0073e9SAndroid Build Coastguard Worker if (conj_a or conj_b) and not dtype.is_complex: 4077*da0073e9SAndroid Build Coastguard Worker continue 4078*da0073e9SAndroid Build Coastguard Worker # We just expand on the batch size 4079*da0073e9SAndroid Build Coastguard Worker if (expand_a or expand_b) and b == 1: 4080*da0073e9SAndroid Build Coastguard Worker continue 4081*da0073e9SAndroid Build Coastguard Worker 4082*da0073e9SAndroid Build Coastguard Worker size_a = (b, n, n) if left else (b, k, k) 4083*da0073e9SAndroid Build Coastguard Worker size_b = (b, n, k) if not tr_b else (b, k, n) 4084*da0073e9SAndroid Build Coastguard Worker 4085*da0073e9SAndroid Build Coastguard Worker # If expand_a or expand_b, we'll expand them to the correct size later 4086*da0073e9SAndroid Build Coastguard Worker if b == 1 or expand_a: 4087*da0073e9SAndroid Build Coastguard Worker size_a = size_a[1:] 4088*da0073e9SAndroid Build Coastguard Worker if b == 1 or expand_b: 4089*da0073e9SAndroid Build Coastguard Worker size_b = size_b[1:] 4090*da0073e9SAndroid Build Coastguard Worker 4091*da0073e9SAndroid Build Coastguard Worker if well_conditioned: 4092*da0073e9SAndroid Build Coastguard Worker PLU = 
torch.linalg.lu(make_fullrank(*size_a)) 4093*da0073e9SAndroid Build Coastguard Worker if uni: 4094*da0073e9SAndroid Build Coastguard Worker # A = L from PLU 4095*da0073e9SAndroid Build Coastguard Worker A = PLU[1].transpose(-2, -1).contiguous() 4096*da0073e9SAndroid Build Coastguard Worker else: 4097*da0073e9SAndroid Build Coastguard Worker # A = U from PLU 4098*da0073e9SAndroid Build Coastguard Worker A = PLU[2].contiguous() 4099*da0073e9SAndroid Build Coastguard Worker else: 4100*da0073e9SAndroid Build Coastguard Worker A = make_arg(size_a) 4101*da0073e9SAndroid Build Coastguard Worker A.triu_() 4102*da0073e9SAndroid Build Coastguard Worker 4103*da0073e9SAndroid Build Coastguard Worker diag = A.diagonal(0, -2, -1) 4104*da0073e9SAndroid Build Coastguard Worker if uni: 4105*da0073e9SAndroid Build Coastguard Worker diag.fill_(1.) 4106*da0073e9SAndroid Build Coastguard Worker else: 4107*da0073e9SAndroid Build Coastguard Worker diag[diag.abs() < 1e-6] = 1. 4108*da0073e9SAndroid Build Coastguard Worker 4109*da0073e9SAndroid Build Coastguard Worker B = make_arg(size_b) 4110*da0073e9SAndroid Build Coastguard Worker 4111*da0073e9SAndroid Build Coastguard Worker if tr_a: 4112*da0073e9SAndroid Build Coastguard Worker A.transpose_(-2, -1) 4113*da0073e9SAndroid Build Coastguard Worker if tr_b: 4114*da0073e9SAndroid Build Coastguard Worker B.transpose_(-2, -1) 4115*da0073e9SAndroid Build Coastguard Worker if conj_a: 4116*da0073e9SAndroid Build Coastguard Worker A = A.conj() 4117*da0073e9SAndroid Build Coastguard Worker if conj_b: 4118*da0073e9SAndroid Build Coastguard Worker B = B.conj() 4119*da0073e9SAndroid Build Coastguard Worker if expand_a: 4120*da0073e9SAndroid Build Coastguard Worker A = A.expand(b, *size_a) 4121*da0073e9SAndroid Build Coastguard Worker if expand_b: 4122*da0073e9SAndroid Build Coastguard Worker B = B.expand(b, n, k) 4123*da0073e9SAndroid Build Coastguard Worker yield A, B, left, not tr_a, uni 4124*da0073e9SAndroid Build Coastguard Worker 4125*da0073e9SAndroid Build Coastguard Worker def _test_linalg_solve_triangular(self, A, B, upper, left, uni): 4126*da0073e9SAndroid Build Coastguard Worker X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni) 4127*da0073e9SAndroid Build Coastguard Worker if left: 4128*da0073e9SAndroid Build Coastguard Worker self.assertEqual(A @ X, B) 4129*da0073e9SAndroid Build Coastguard Worker else: 4130*da0073e9SAndroid Build Coastguard Worker self.assertEqual(X @ A, B) 4131*da0073e9SAndroid Build Coastguard Worker out = B 4132*da0073e9SAndroid Build Coastguard Worker # B may be expanded 4133*da0073e9SAndroid Build Coastguard Worker if not B.is_contiguous() and not B.transpose(-2, -1).is_contiguous(): 4134*da0073e9SAndroid Build Coastguard Worker out = B.clone() 4135*da0073e9SAndroid Build Coastguard Worker torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni, out=out) 4136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(X, out) 4137*da0073e9SAndroid Build Coastguard Worker 4138*da0073e9SAndroid Build Coastguard Worker # Tolerances dictated by widest acceptable range on CPU before failure 4139*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 4140*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-3 if TEST_WITH_ROCM else 1e-1, 4141*da0073e9SAndroid Build Coastguard Worker torch.float64: 1e-8, 4142*da0073e9SAndroid Build Coastguard Worker torch.complex64: 1e-1, 4143*da0073e9SAndroid Build Coastguard Worker torch.complex128: 1e-8}) 
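    # solve_triangular solves A @ X == B when left=True and X @ A == B when left=False, with A
    # assumed (unit-)triangular; the generator above yields transposed, conjugated and
    # batch-expanded variants of both layouts.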
4144*da0073e9SAndroid Build Coastguard Worker def test_linalg_solve_triangular(self, device, dtype): 4145*da0073e9SAndroid Build Coastguard Worker # This exercises the API + BLAS CPU + batched cuBLAS 4146*da0073e9SAndroid Build Coastguard Worker ks = (3, 1, 0) 4147*da0073e9SAndroid Build Coastguard Worker ns = (5, 0) 4148*da0073e9SAndroid Build Coastguard Worker bs = (1, 2, 0) 4149*da0073e9SAndroid Build Coastguard Worker 4150*da0073e9SAndroid Build Coastguard Worker gen_inputs = self._gen_shape_inputs_linalg_triangular_solve 4151*da0073e9SAndroid Build Coastguard Worker for b, n, k in product(bs, ns, ks): 4152*da0073e9SAndroid Build Coastguard Worker for A, B, left, upper, uni in gen_inputs((b, n, k), dtype, device, well_conditioned=True): 4153*da0073e9SAndroid Build Coastguard Worker self._test_linalg_solve_triangular(A, B, upper, left, uni) 4154*da0073e9SAndroid Build Coastguard Worker 4155*da0073e9SAndroid Build Coastguard Worker @slowTest 4156*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra") 4157*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 4158*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagma # Magma needed for the PLU decomposition 4159*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 4160*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2, 4161*da0073e9SAndroid Build Coastguard Worker torch.float64: 1e-8, torch.complex128: 1e-8}) 4162*da0073e9SAndroid Build Coastguard Worker def test_linalg_solve_triangular_large(self, device, dtype): 4163*da0073e9SAndroid Build Coastguard Worker # Exercises magma and cublas 4164*da0073e9SAndroid Build Coastguard Worker magma = (9, 513, 1) 4165*da0073e9SAndroid Build Coastguard Worker iterative_cublas = (2, 64, 1) 4166*da0073e9SAndroid Build Coastguard Worker 4167*da0073e9SAndroid Build Coastguard Worker gen_inputs = self._gen_shape_inputs_linalg_triangular_solve 4168*da0073e9SAndroid Build Coastguard Worker for shape in (magma, iterative_cublas): 4169*da0073e9SAndroid Build Coastguard Worker for A, B, left, upper, uni in gen_inputs(shape, dtype, device, well_conditioned=True): 4170*da0073e9SAndroid Build Coastguard Worker self._test_linalg_solve_triangular(A, B, upper, left, uni) 4171*da0073e9SAndroid Build Coastguard Worker 4172*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 4173*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2, 4174*da0073e9SAndroid Build Coastguard Worker torch.float64: 1e-8, torch.complex128: 1e-8}) 4175*da0073e9SAndroid Build Coastguard Worker def test_linalg_solve_triangular_broadcasting(self, device, dtype): 4176*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, dtype=dtype, device=device) 4177*da0073e9SAndroid Build Coastguard Worker 4178*da0073e9SAndroid Build Coastguard Worker sizes = (((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)), 4179*da0073e9SAndroid Build Coastguard Worker ((2, 1, 3, 4, 4), (4, 6)), 4180*da0073e9SAndroid Build Coastguard Worker ((4, 4), (2, 1, 3, 4, 2)), 4181*da0073e9SAndroid Build Coastguard Worker ((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))) 4182*da0073e9SAndroid Build Coastguard Worker for size_A, size_B in sizes: 4183*da0073e9SAndroid Build Coastguard Worker for left, upper, uni in itertools.product([True, False], repeat=3): 4184*da0073e9SAndroid Build Coastguard Worker A = make_arg(size_A) 4185*da0073e9SAndroid 
Build Coastguard Worker if upper: 4186*da0073e9SAndroid Build Coastguard Worker A.triu_() 4187*da0073e9SAndroid Build Coastguard Worker else: 4188*da0073e9SAndroid Build Coastguard Worker A.tril_() 4189*da0073e9SAndroid Build Coastguard Worker diag = A.diagonal(0, -2, -1) 4190*da0073e9SAndroid Build Coastguard Worker if uni: 4191*da0073e9SAndroid Build Coastguard Worker diag.fill_(1.) 4192*da0073e9SAndroid Build Coastguard Worker else: 4193*da0073e9SAndroid Build Coastguard Worker diag[diag.abs() < 1e-6] = 1. 4194*da0073e9SAndroid Build Coastguard Worker B = make_arg(size_B) 4195*da0073e9SAndroid Build Coastguard Worker if not left: 4196*da0073e9SAndroid Build Coastguard Worker B.transpose_(-2, -1) 4197*da0073e9SAndroid Build Coastguard Worker 4198*da0073e9SAndroid Build Coastguard Worker X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni) 4199*da0073e9SAndroid Build Coastguard Worker if left: 4200*da0073e9SAndroid Build Coastguard Worker B_other = A @ X 4201*da0073e9SAndroid Build Coastguard Worker else: 4202*da0073e9SAndroid Build Coastguard Worker B_other = X @ A 4203*da0073e9SAndroid Build Coastguard Worker 4204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(*torch.broadcast_tensors(B, B_other)) 4205*da0073e9SAndroid Build Coastguard Worker 4206*da0073e9SAndroid Build Coastguard Worker def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, 4207*da0073e9SAndroid Build Coastguard Worker device, dtype): 4208*da0073e9SAndroid Build Coastguard Worker triangle_function = torch.triu if upper else torch.tril 4209*da0073e9SAndroid Build Coastguard Worker b = torch.randn(*b_dims, dtype=dtype, device=device) 4210*da0073e9SAndroid Build Coastguard Worker A = torch.randn(*A_dims, dtype=dtype, device=device) 4211*da0073e9SAndroid Build Coastguard Worker # create positive definite matrix 4212*da0073e9SAndroid Build Coastguard Worker A = torch.matmul(A, A.mT) 4213*da0073e9SAndroid Build Coastguard Worker A_triangular = triangle_function(A) 4214*da0073e9SAndroid Build Coastguard Worker if unitriangular: 4215*da0073e9SAndroid Build Coastguard Worker A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) 
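        # with unitriangular=True the solver ignores the stored diagonal, but writing 1s here
        # keeps the explicit A used by the reference matmul checks consistent with the solve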
        return b, A_triangular

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipIfTorchDynamo("flaky, needs investigation")
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_triangular_solve(self, device, dtype):
        ks = [0, 1, 3]
        ns = [0, 5]
        for k, n, (upper, unitriangular, transpose) in itertools.product(ks, ns,
                                                                         itertools.product([True, False], repeat=3)):
            b, A = self.triangular_solve_test_helper((n, n), (n, k), upper,
                                                     unitriangular, device, dtype)
            x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
            if transpose:
                self.assertEqual(b, np.matmul(A.t().cpu(), x.cpu()))
            else:
                self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_triangular_solve_batched(self, device, dtype):
        def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
                                                     unitriangular, device, dtype)
            x_exp_list = []
            for i in range(b_dims[0]):
                x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper,
                                                         unitriangular=unitriangular,
                                                         transpose=transpose)[0])
            x_exp = torch.stack(x_exp_list)  # Stacked output
            x_act = torch.triangular_solve(b, A, upper=upper,
                                           unitriangular=unitriangular,
                                           transpose=transpose)[0]  # Actual output
            self.assertEqual(x_act, x_exp)  # Equality check
            if transpose:
                A = A.mT

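            # Sanity-check the solution against the definition: after undoing the optional
            # transpose above, A @ x must reproduce the right-hand side b.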
            Ax = np.matmul(A.cpu(), x_act.cpu())
            self.assertEqual(b, Ax)

        def triangular_solve_zero_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
                                                     unitriangular, device, dtype)
            x = torch.triangular_solve(b, A, upper=upper,
                                       unitriangular=unitriangular,
                                       transpose=transpose)[0]
            self.assertTrue(x.shape == b.shape)

        for upper, unitriangular, transpose in itertools.product([True, False], repeat=3):
            batchsize = 3
            triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
                                          upper, unitriangular, transpose)

            # test empty input
            triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 10),
                                          upper, unitriangular, transpose)
            triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 0),
                                          upper, unitriangular, transpose)

            # test zero batch case
            batchsize = 0
            triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
                                               upper, unitriangular, transpose)

    @slowTest
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_triangular_solve_batched_many_batches(self, device, dtype):
        for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
            # test batched A case
            b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1),
                                                     upper, unitriangular, device, dtype)
            x, _ = torch.triangular_solve(b, A,
                                          upper=upper, transpose=transpose, unitriangular=unitriangular)
            if transpose:
                A = A.mT
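            # b is a single (5, 1) right-hand side that broadcasts against the
            # (256, 256) batch of A, so compare against b expanded to Ax's shape.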

            Ax = torch.matmul(A, x)

            rtol = 1e-2 if dtype in [torch.float32, torch.complex64] else self.precision
            self.assertEqual(Ax, b.expand_as(Ax), atol=self.precision, rtol=rtol)

            # test batched b case
            b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1),
                                                     upper, unitriangular, device, dtype)
            x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose,
                                          unitriangular=unitriangular)
            if transpose:
                A = A.mT

            self.assertEqual(torch.matmul(A, x), b)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @skipIfTorchDynamo("flaky, needs investigation")
    @dtypes(*floating_and_complex_types())
    def test_triangular_solve_batched_broadcasting(self, device, dtype):
        from scipy.linalg import solve_triangular as tri_solve

        def scipy_tri_solve_batched(A, B, upper, trans, diag):
            batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
            single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
            expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
                                                     torch.Size(batch_dims_B)))
            expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
            expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
            flat_A = expand_A.reshape((-1,) + single_dim_A)
            flat_B = expand_B.reshape((-1,) + single_dim_B)
            flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag)
                                for a, b in zip(flat_A, flat_B)])
            return flat_X.reshape(expand_B.shape)

        def run_test(A_dims, b_dims, device, upper, transpose, unitriangular):
            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
                                                     unitriangular, device, dtype)
            x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(),
                                                            upper, transpose, unitriangular))
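            # x_exp is the SciPy reference solution (with broadcasting done manually above);
            # the torch result below must match it once moved back to the test device.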
            x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]

            self.assertEqual(x, x_exp.to(device))

        for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
            # test against scipy.linalg.solve_triangular
            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular)  # no broadcasting
            run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular)  # broadcasting b
            run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular)  # broadcasting A
            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular)  # broadcasting A & b

    @onlyCUDA
    @dtypes(torch.float)
    def test_triangular_solve_large(self, device, dtype):
        # Repro for https://github.com/pytorch/pytorch/issues/79191
        A = torch.randn(1, 2, 2, device=device, dtype=dtype).tril_()
        B = torch.randn(1, 2, 524281, device=device, dtype=dtype)
        X = torch.linalg.solve_triangular(A, B, upper=False)
        self.assertEqual(A @ X, B)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_triangular_solve_out_errors_and_warnings(self, device, dtype):
        # dtypes should be safely castable
        a = torch.eye(2, dtype=dtype, device=device)
        b = torch.randn(2, 1, dtype=dtype, device=device)
        out = torch.empty_like(b).to(torch.int)
        clone_a = torch.empty_like(a)
        with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
            torch.triangular_solve(b, a, out=(out, clone_a))

        out = torch.empty_like(b)
        clone_a = clone_a.to(torch.int)
        with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
            torch.triangular_solve(b, a, out=(out, clone_a))

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, dtype=dtype, device=wrong_device)
            clone_a = torch.empty_like(a)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.triangular_solve(b, a, out=(out, clone_a))
            out = torch.empty(0, dtype=dtype, device=device)
            clone_a = torch.empty_like(a).to(wrong_device)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.triangular_solve(b, a, out=(out, clone_a))

        # Trigger the WARN_ONCE deprecation warning
        torch.triangular_solve(b, a)

        # if an out tensor with the wrong shape is passed, a warning is given
        with warnings.catch_warnings(record=True) as w:
            out = torch.empty(1, dtype=dtype, device=device)
            clone_a = torch.empty(1, dtype=dtype, device=device)
            # Trigger warning
            torch.triangular_solve(b, a, out=(out, clone_a))
            # Check warning occurs
            self.assertEqual(len(w), 2)
            self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
            self.assertTrue("An output with one or more elements was resized" in str(w[1].message))

    def check_single_matmul(self, x, y):

        def assertEqual(answer, expected):
            if x.dtype.is_floating_point or x.dtype.is_complex:
                k = max(x.shape[-1], 1)  # Scale the atol with the size of the matrix
                self.assertEqual(answer, expected,
                                 msg=f"{x.shape} x {y.shape} = {answer.shape}",
                                 atol=k * 5e-5,
                                 rtol=1e-4)
            else:
                self.assertEqual(answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}")

        # test x @ y
        expected = np.matmul(x.cpu(), y.cpu())
        ans = torch.matmul(x, y)
        self.assertTrue(ans.is_contiguous())
        assertEqual(ans, expected)
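        # The out= variant below must hand back the very tensor that was passed in and
        # match the functional result computed above.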

        # test out
        out = torch.empty_like(ans)
        ans = torch.matmul(x, y, out=out)
        self.assertIs(ans, out)
        self.assertTrue(ans.is_contiguous())
        assertEqual(ans, expected)

    def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
        """
        Generates sequences of tuples (x, y) with size(x) = x_dim and
        size(y) <= y_dim that are compatible w.r.t. matmul
        """
        assert x_dim >= 1
        assert y_dim >= 2
        x = x_dim
        for y in range(1, y_dim + 1):
            for batch, mn in product(product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
                                     product(range(matrix_size), repeat=min(y, 2))):
                if x == 1:
                    size_x = mn[:1]
                    size_y = batch + mn
                    yield size_x, size_y
                else:
                    for k in range(matrix_size):
                        size_x = (k,) + mn[:1]
                        if x > 2:
                            size_x = batch[-(x - 2):] + size_x
                        size_y = mn
                        if y > 2:
                            size_y = batch[-(y - 2):] + size_y
                        yield size_x, size_y

    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
    @dtypes(torch.int64, torch.float, torch.complex64)
    @setBlasBackendsToDefaultFinally
    def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
        for backend in ["cublas", "cublaslt"]:
            if torch.device(device).type == 'cuda':
                torch.backends.cuda.preferred_blas_library(backend)

            make_arg = partial(make_tensor, device=device, dtype=dtype)

            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
                x = make_arg(size_x, noncontiguous=nctg_x)
                y = make_arg(size_y, noncontiguous=nctg_y)
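                # gen_sizes_matmul(1) pairs a 1-D x with operands from 1-D up to 4-D,
                # including size-0 dimensions, so the degenerate empty cases are covered too.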
                self.check_single_matmul(x, y)

    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
    @dtypes(torch.int64, torch.float, torch.complex64)
    @setBlasBackendsToDefaultFinally
    def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
        for backend in ["cublas", "cublaslt"]:
            if torch.device(device).type == 'cuda':
                torch.backends.cuda.preferred_blas_library(backend)

            make_arg = partial(make_tensor, device=device, dtype=dtype)

            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)):
                x = make_arg(size_x, noncontiguous=nctg_x)
                y = make_arg(size_y, noncontiguous=nctg_y)
                self.check_single_matmul(x, y)

    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
    @dtypes(torch.int64, torch.float, torch.complex64)
    @setBlasBackendsToDefaultFinally
    def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
        for backend in ["cublas", "cublaslt"]:
            if torch.device(device).type == 'cuda':
                torch.backends.cuda.preferred_blas_library(backend)

            make_arg = partial(make_tensor, device=device, dtype=dtype)

            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(3), (True, False), (True, False)):
                x = make_arg(size_x, noncontiguous=nctg_x)
                y = make_arg(size_y, noncontiguous=nctg_y)
                self.check_single_matmul(x, y)

    @onlyCUDA
    @dtypes(*floating_types_and(torch.half))
    def test_matmul_small_brute_force_tunableop(self, device, dtype):
        # disable tunableop buffer rotation for all tests everywhere, it can be slow
        import os
        os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0"
        set_tunableop_defaults()

        torch.cuda.tunable.enable()
        # set these to single iterations to keep it short but still exercise the code
        torch.cuda.tunable.set_max_tuning_duration(1)
        torch.cuda.tunable.set_max_tuning_iterations(1)

        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
            x = make_arg(size_x, noncontiguous=nctg_x)
            y = make_arg(size_y, noncontiguous=nctg_y)
            self.check_single_matmul(x, y)

        filename1 = torch.cuda.tunable.get_filename()
        filename2 = "tunableop_results_tmp1.csv"
        filename3 = "tunableop_results_tmp2.csv"
        ordinal = torch.cuda.current_device()
        assert filename1 == f"tunableop_results{ordinal}.csv"
        assert len(torch.cuda.tunable.get_validators()) > 0
        validators = {}
        for key, value in torch.cuda.tunable.get_validators():
            validators[key] = value
        if torch.version.hip:
            assert "HIPBLASLT_VERSION" in validators
            assert re.match(r'^\d{3}-[a-z0-9]{8}$', validators["HIPBLASLT_VERSION"])
        assert len(torch.cuda.tunable.get_results()) > 0

        assert torch.cuda.tunable.write_file()  # use default filename
        assert torch.cuda.tunable.write_file(filename2)  # use custom, one-time filename
        torch.cuda.tunable.set_filename(filename3)
        assert torch.cuda.tunable.write_file()  # use previously set filename
        assert torch.cuda.tunable.read_file()  # use previously set filename, will ignore duplicates and return True

        with open(filename1) as file1:
            file1_contents = file1.read()
        with open(filename2) as file2:
            file2_contents = file2.read()
        with open(filename3) as file3:
            file3_contents = file3.read()
        assert file1_contents == file2_contents
        assert file1_contents == file3_contents

        # remove the files created above to avoid error 'Build left local git repository checkout dirty', ignore errors
        for filename in [filename1, filename2, filename3]:
            try:
                import os
                os.remove(filename)
            except FileNotFoundError:
                pass

        # disables TunableOp
        torch.cuda.tunable.enable(False)

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_bmm_tunableop_rocm(self, device, dtype):
        # buffer rotation (on by default) with strided batched gemm tunableop was causing a mem fault
        set_tunableop_defaults()
        torch.cuda.tunable.enable(True)
        torch.cuda.tunable.set_max_tuning_iterations(10)
        # the following cases cover all previous failure cases and are here to catch regressions
        B = 16
        N = M = K = 256
        dtype = torch.bfloat16
        device = torch.device("cuda:0")
        # case 1
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        out = torch.bmm(i1, i2)
        # case 2
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i1 = torch.permute(i1, (1, 2, 0))
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        i2 = torch.permute(i2, (1, 0, 2))
        out = torch.bmm(i1, i2)
        # case 3
        i1 = torch.randn((N, B, M), device=device, dtype=dtype)
        i1 = torch.permute(i1, (1, 0, 2))
        i2 = torch.randn((M, B, K), device=device, dtype=dtype)
        i2 = torch.permute(i2, (1, 2, 0))
        out = torch.bmm(i1, i2)
        # case 4
        input_tensor = torch.rand((1920, 1, 100), device=device, dtype=dtype)
        input_tensor = torch.as_strided(
            input_tensor, size=(1920, 1, 100), stride=(100, 100, 1)
        )
        batch1_tensor = torch.rand((1920, 256, 512), device=device, dtype=dtype)
        batch1_tensor = torch.as_strided(
            batch1_tensor, size=(1920, 256, 512), stride=(512, 983040, 1)
        )
        batch2_tensor = torch.rand((1920, 512, 100), device=device, dtype=dtype)
        batch2_tensor = torch.as_strided(
            batch2_tensor, size=(1920, 512, 100), stride=(51200, 100, 1)
        )
        out = torch.baddbmm(input_tensor, batch1_tensor, batch2_tensor)
        # clean up, remove any file that was generated
        try:
            import os
            filename = torch.cuda.tunable.get_filename()
            os.remove(filename)
        except FileNotFoundError:
            pass

        # disable TunableOp
        torch.cuda.tunable.enable(False)

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
        from torch.testing._internal.common_utils import CudaMemoryLeakCheck
        import os
        # run operator first without tuning to ensure all rocm libs are loaded,
        # otherwise false positive mem leak
        B = 16
        N = M = K = 256
        dtype = torch.bfloat16
        device = torch.device("cuda:0")
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        out = torch.bmm(i1, i2)
        # enable tunableop numeric check via env variable.
        PYTORCH_TUNABLEOP_NUMERICAL_CHECK = "PYTORCH_TUNABLEOP_NUMERICAL_CHECK"
        prev_val = os.getenv(PYTORCH_TUNABLEOP_NUMERICAL_CHECK)
        try:
            os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = "1"
            torch.cuda.tunable.enable(True)
            ordinal = torch.cuda.current_device()
            filename = f"tunableop_results{ordinal}.csv"
            torch.cuda.tunable.set_filename(filename)
            iterations = torch.cuda.tunable.get_max_tuning_iterations()
            torch.cuda.tunable.set_max_tuning_iterations(10)
            with CudaMemoryLeakCheck(self):
                out = torch.bmm(i1, i2)
            torch.cuda.tunable.set_max_tuning_iterations(iterations)
            torch.cuda.tunable.enable(False)
            # clean up, remove any file that was generated
            try:
                os.remove(filename)
            except FileNotFoundError:
                pass
        finally:
            if prev_val is None:
                del os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK]
            else:
                os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = prev_val

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_validator_tunableop_rocm(self, device, dtype):
        # Test that the validator on ROCm has exactly 5 lines
        # Format of the Validator is as follows:
        # Validator,PT_VERSION,X.Y.Z.
        # Validator,ROCBLAS_VERSION,X.Y,Z
        # Validator,HIPBLASLT_VERSION,X,Y.Z
        # Validator,ROCM_Version,X,Y.Z
        # Validator,GCN_ARCH_NAME,<architecture name>
        validator_num_lines = 5

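        # get_validators() yields (key, value) pairs, e.g. ("PT_VERSION", "2.4.0");
        # the exact values depend on the local ROCm / hipBLASLt install.
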
        # Test in try-finally block to avoid leaking state
        # if test is interrupted.
        try:
            set_tunableop_defaults()
            torch.cuda.tunable.enable()
            # set these to single iterations to keep it short but still exercise the code
            torch.cuda.tunable.set_max_tuning_iterations(1)

            N = M = K = 4
            A = torch.randn(N, K, device=device, dtype=dtype)
            B = torch.randn(K, M, device=device, dtype=dtype)
            C = torch.matmul(A, B)
            self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
        finally:
            # disable TunableOp
            torch.cuda.tunable.enable(False)

            # clean up, remove any file that was generated
            try:
                import os
                filename = torch.cuda.tunable.get_filename()
                os.remove(filename)
            except FileNotFoundError:
                pass

    @onlyCUDA
    @dtypes(torch.half)
    def test_minimum_tuning_iteration_tunableop(self, device, dtype):
        # Make sure that there is at least one tuning iteration under various scenarios

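        # Two scenarios are exercised below: the max tuning duration forced to zero and the
        # max tuning iterations forced to zero; in both cases TunableOp should still report
        # a positive iteration count and record a new result for the tuned GEMM.
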
        # Test in try-finally block to avoid leaking state
        # if test is interrupted.
        try:
            set_tunableop_defaults()
            torch.cuda.tunable.enable()
            # set these to single iterations to keep it short but still exercise the code
            torch.cuda.tunable.set_max_tuning_iterations(1)

            # Set tuning duration to zero milliseconds
            # Tune a single GEMM and verify that we get a new tuning result
            import os
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "0"
            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "30"  # reset to default

            # Reference number of results
            ref_num_results = len(torch.cuda.tunable.get_results())

            N = M = K = 8
            A = torch.randn(N, K, device=device, dtype=dtype)
            B = torch.randn(K, M, device=device, dtype=dtype)
            C = torch.matmul(A, B)

            # This stores the total number of cumulative results
            total_num_results = len(torch.cuda.tunable.get_results())

            # There must be a new tuning result
            self.assertEqual((total_num_results - ref_num_results), 1)

            # Set tuning iterations to zero
            # Tune a single GEMM and verify that we get a new tuning result
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "0"
            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "100"  # reset to default

            # Reference number of results
            ref_num_results = total_num_results

            N = M = K = 16
            A = torch.randn(N, K, device=device, dtype=dtype)
            B = torch.randn(K, M, device=device, dtype=dtype)
            C = torch.matmul(A, B)

            # This stores the total number of cumulative results
            total_num_results = len(torch.cuda.tunable.get_results())

            # There must be a new tuning result
            self.assertEqual((total_num_results - ref_num_results), 1)

        finally:
            # disable TunableOp
            torch.cuda.tunable.enable(False)

            # clean up, remove any file that was generated
            try:
                import os
                filename = torch.cuda.tunable.get_filename()
                os.remove(filename)
            except FileNotFoundError:
                pass

    @onlyCUDA
    @dtypes(torch.half)
    def test_matmul_check_entries_tunableop(self, device, dtype):
        # Tune a couple of matrix multiplies
        # Verify we get the correct number of results

        try:
            set_tunableop_defaults()
            torch.cuda.tunable.enable()
            # set these to single iterations to keep it short but still exercise the code
            torch.cuda.tunable.set_max_tuning_iterations(1)

            # Reference number of results
            ref_num_results = len(torch.cuda.tunable.get_results())

            # Execute matrix multiplies. We intentionally repeat one M value in the list;
            # the CSV file should only record unique GEMMs
            count_matmul = 4
            K = 64
            for M in [32, 64, 32]:
                for N in [32, 64]:
                    A = torch.randn(N, K, device=device, dtype=dtype)
                    B = torch.randn(K, M, device=device, dtype=dtype)
                    C = torch.matmul(A, B)

            # This stores the total number of cumulative results
            total_num_results = len(torch.cuda.tunable.get_results())

            # Take the difference to calculate the number of results from
            # this test and verify that it agrees with the number of
            # unique GEMMs.
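            # With M in [32, 64, 32] and N in [32, 64] the loops run six matmuls but only
            # four unique (M, N) shapes: (32, 32), (32, 64), (64, 32) and (64, 64), hence
            # count_matmul = 4.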
            self.assertEqual((total_num_results - ref_num_results), count_matmul)

        finally:
            # disable TunableOp
            torch.cuda.tunable.enable(False)

            # clean up, remove any file that was generated
            try:
                import os
                filename = torch.cuda.tunable.get_filename()
                os.remove(filename)
            except FileNotFoundError:
                pass

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_bmm_tunableop_rocm(self, device, dtype):
        # buffer rotation (on by default) with strided batched gemm tunableop was causing a mem fault
        torch.cuda.tunable.enable(True)
        ordinal = torch.cuda.current_device()
        filename = f"tunableop_results{ordinal}.csv"
        torch.cuda.tunable.set_filename(filename)
        iterations = torch.cuda.tunable.get_max_tuning_iterations()
        torch.cuda.tunable.set_max_tuning_iterations(10)
        # the following 3 cases cover all previous failure cases and are here to catch regressions
        B = 16
        N = M = K = 256
        dtype = torch.bfloat16
        device = torch.device("cuda:0")
        # case 1
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        out = torch.bmm(i1, i2)
        # case 2
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i1 = torch.permute(i1, (1, 2, 0))
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        i2 = torch.permute(i2, (1, 0, 2))
        out = torch.bmm(i1, i2)
        # case 3
        i1 = torch.randn((N, B, M), device=device, dtype=dtype)
        i1 = torch.permute(i1, (1, 0, 2))
        i2 = torch.randn((M, B, K), device=device, dtype=dtype)
        i2 = torch.permute(i2, (1, 2, 0))
        out = torch.bmm(i1, i2)
        # clean up, remove any file that was generated
        try:
            import os
            os.remove(filename)
        except FileNotFoundError:
            pass
        # reset back to prior settings
        torch.cuda.tunable.set_max_tuning_iterations(iterations)
        torch.cuda.tunable.enable(False)

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
        from torch.testing._internal.common_utils import CudaMemoryLeakCheck
        import os
        # run operator first without tuning to ensure all rocm libs are loaded,
        # otherwise false positive mem leak
        B = 16
        N = M = K = 256
        dtype = torch.bfloat16
        device = torch.device("cuda:0")
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        out = torch.bmm(i1, i2)
        # enable tunableop numeric check via env variable.
        PYTORCH_TUNABLEOP_NUMERICAL_CHECK = "PYTORCH_TUNABLEOP_NUMERICAL_CHECK"
        prev_val = os.getenv(PYTORCH_TUNABLEOP_NUMERICAL_CHECK)
        try:
            os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = "1"
            torch.cuda.tunable.enable(True)
            ordinal = torch.cuda.current_device()
            filename = f"tunableop_results{ordinal}.csv"
            torch.cuda.tunable.set_filename(filename)
            iterations = torch.cuda.tunable.get_max_tuning_iterations()
            torch.cuda.tunable.set_max_tuning_iterations(10)
            with CudaMemoryLeakCheck(self):
                out = torch.bmm(i1, i2)
            torch.cuda.tunable.set_max_tuning_iterations(iterations)
            torch.cuda.tunable.enable(False)
            # clean up, remove any file that was generated
            try:
                os.remove(filename)
            except FileNotFoundError:
                pass
        finally:
            if prev_val is None:
                del os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK]
            else:
                os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = prev_val

    @dtypes(torch.float, torch.complex64)
    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
        a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
        b = torch.empty((4, 128, 512), device=device, dtype=dtype, requires_grad=True).transpose(-1, -2)
        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)

        torch.matmul(a.detach(), b.detach(), out=c)

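        # matmul into a preallocated out= tensor cannot be recorded by autograd, so the same
        # call must raise once the inputs require grad; detaching (above) or torch.no_grad()
        # (below) are the supported ways to combine out= with these inputs.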
        with self.assertRaisesRegex(RuntimeError, "functions with out=... arguments don't support automatic differentiation"):
            torch.matmul(a, b, out=c)

        with torch.no_grad():
            torch.matmul(a, b, out=c)

    # 4GB should do, but we run tests in parallel in CI, so let's be generous
    @largeTensorTest('16GB', device='cuda')
    def test_large_bmm_mm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
        B = torch.randn([1024, 65536], device="cuda", requires_grad=True)
        G = torch.randn([1024, 2, 65536], device="cuda")

        # Should not create an intermediate tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

    # 4GB should do, but we run tests in parallel in CI, so let's be generous
    @largeTensorTest('16GB', device='cuda')
    def test_large_bmm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
        B = torch.randn([1, 1024, 65536], device="cuda", requires_grad=True)
        G = torch.randn([1024, 2, 65536], device="cuda")

        # Should not create an intermediate tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

    def test_linear_algebra_scalar_raises(self, device) -> None:
        m = torch.randn(5, 5, device=device)
        v = torch.randn(5, device=device)
        s = torch.tensor(7, device=device)
        self.assertRaises(RuntimeError, lambda: torch.mv(m, s))
        self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s))

    @dtypes(torch.float32, torch.complex64)
    def test_cross(self, device, dtype):
        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
        y = torch.rand(100, 3, 100, dtype=dtype, device=device)
        res1 = torch.cross(x, y)
        res2 = torch.tensor((), dtype=dtype, device=device)
        torch.cross(x, y, out=res2)
        self.assertEqual(res1, res2)

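    # Note: torch.cross without an explicit dim operates over the first dimension of size 3,
    # while torch.linalg.cross defaults to dim=-1. The "*_with_and_without_dim" tests below
    # use (100, 3) inputs, where both conventions resolve to the same dimension.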
4941*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.complex64) 4942*da0073e9SAndroid Build Coastguard Worker def test_linalg_cross(self, device, dtype): 4943*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 3, 100, dtype=dtype, device=device) 4944*da0073e9SAndroid Build Coastguard Worker y = torch.rand(100, 3, 100, dtype=dtype, device=device) 4945*da0073e9SAndroid Build Coastguard Worker res1 = torch.linalg.cross(x, y, dim=1) 4946*da0073e9SAndroid Build Coastguard Worker res2 = torch.tensor((), dtype=dtype, device=device) 4947*da0073e9SAndroid Build Coastguard Worker torch.linalg.cross(x, y, dim=1, out=res2) 4948*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 4949*da0073e9SAndroid Build Coastguard Worker 4950*da0073e9SAndroid Build Coastguard Worker # test for broadcastable inputs 4951*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 3, 2, dtype=dtype, device=device) 4952*da0073e9SAndroid Build Coastguard Worker y = torch.rand(4, 3, 1, dtype=dtype, device=device) 4953*da0073e9SAndroid Build Coastguard Worker res1 = torch.linalg.cross(x, y, dim=1) 4954*da0073e9SAndroid Build Coastguard Worker res2 = torch.tensor((), dtype=dtype, device=device) 4955*da0073e9SAndroid Build Coastguard Worker torch.linalg.cross(x, y, dim=1, out=res2) 4956*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 4957*da0073e9SAndroid Build Coastguard Worker 4958*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.complex64) 4959*da0073e9SAndroid Build Coastguard Worker def test_cross_with_and_without_dim(self, device, dtype): 4960*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 3, dtype=dtype, device=device) 4961*da0073e9SAndroid Build Coastguard Worker y = torch.rand(100, 3, dtype=dtype, device=device) 4962*da0073e9SAndroid Build Coastguard Worker res1 = torch.cross(x, y, dim=1) 4963*da0073e9SAndroid Build Coastguard Worker res2 = torch.cross(x, y, dim=-1) 4964*da0073e9SAndroid Build Coastguard Worker res3 = torch.cross(x, y) 4965*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 4966*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res3) 4967*da0073e9SAndroid Build Coastguard Worker 4968*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.complex64) 4969*da0073e9SAndroid Build Coastguard Worker def test_linalg_cross_with_and_without_dim(self, device, dtype): 4970*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 3, dtype=dtype, device=device) 4971*da0073e9SAndroid Build Coastguard Worker y = torch.rand(100, 3, dtype=dtype, device=device) 4972*da0073e9SAndroid Build Coastguard Worker res1 = torch.linalg.cross(x, y, dim=1) 4973*da0073e9SAndroid Build Coastguard Worker res2 = torch.linalg.cross(x, y, dim=-1) 4974*da0073e9SAndroid Build Coastguard Worker res3 = torch.linalg.cross(x, y) 4975*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 4976*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res3) 4977*da0073e9SAndroid Build Coastguard Worker 4978*da0073e9SAndroid Build Coastguard Worker def test_renorm(self, device): 4979*da0073e9SAndroid Build Coastguard Worker m1 = torch.randn(20, 20, device=device) # big enough to exercise vectorized path 4980*da0073e9SAndroid Build Coastguard Worker res1 = torch.tensor((), device=device) 4981*da0073e9SAndroid Build Coastguard Worker 4982*da0073e9SAndroid Build Coastguard Worker def renorm(matrix, value, dim, max_norm): 4983*da0073e9SAndroid Build Coastguard Worker m1 
= matrix.transpose(dim, 0).contiguous() 4984*da0073e9SAndroid Build Coastguard Worker # collapse non-dim dimensions. 4985*da0073e9SAndroid Build Coastguard Worker m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0)))) 4986*da0073e9SAndroid Build Coastguard Worker norms = m2.norm(value, 1, True) 4987*da0073e9SAndroid Build Coastguard Worker # clip 4988*da0073e9SAndroid Build Coastguard Worker new_norms = norms.clone() 4989*da0073e9SAndroid Build Coastguard Worker new_norms[torch.gt(norms, max_norm)] = max_norm 4990*da0073e9SAndroid Build Coastguard Worker new_norms.div_(norms.add_(1e-7)) 4991*da0073e9SAndroid Build Coastguard Worker # renormalize 4992*da0073e9SAndroid Build Coastguard Worker m1.mul_(new_norms.expand_as(m1)) 4993*da0073e9SAndroid Build Coastguard Worker return m1.transpose(dim, 0) 4994*da0073e9SAndroid Build Coastguard Worker 4995*da0073e9SAndroid Build Coastguard Worker # note that the axis fed to torch.renorm is different (2~=1) 4996*da0073e9SAndroid Build Coastguard Worker maxnorm = m1.norm(2, 1).mean() 4997*da0073e9SAndroid Build Coastguard Worker m2 = renorm(m1, 2, 1, maxnorm) 4998*da0073e9SAndroid Build Coastguard Worker m1.renorm_(2, 1, maxnorm) 4999*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m1, m2, atol=1e-5, rtol=0) 5000*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), atol=1e-5, rtol=0) 5001*da0073e9SAndroid Build Coastguard Worker 5002*da0073e9SAndroid Build Coastguard Worker m1 = torch.randn(3, 4, 5, device=device) 5003*da0073e9SAndroid Build Coastguard Worker m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) 5004*da0073e9SAndroid Build Coastguard Worker maxnorm = m2.norm(2, 0).mean() 5005*da0073e9SAndroid Build Coastguard Worker m2 = renorm(m2, 2, 1, maxnorm) 5006*da0073e9SAndroid Build Coastguard Worker m1.renorm_(2, 1, maxnorm) 5007*da0073e9SAndroid Build Coastguard Worker m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) 5008*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m3, m2) 5009*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m3.norm(2, 0), m2.norm(2, 0)) 5010*da0073e9SAndroid Build Coastguard Worker 5011*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 5012*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 5013*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 5014*da0073e9SAndroid Build Coastguard Worker def test_ormqr(self, device, dtype): 5015*da0073e9SAndroid Build Coastguard Worker 5016*da0073e9SAndroid Build Coastguard Worker def run_test(batch, m, n, fortran_contiguous): 5017*da0073e9SAndroid Build Coastguard Worker A = make_tensor((*batch, m, n), dtype=dtype, device=device) 5018*da0073e9SAndroid Build Coastguard Worker reflectors, tau = torch.geqrf(A) 5019*da0073e9SAndroid Build Coastguard Worker if not fortran_contiguous: 5020*da0073e9SAndroid Build Coastguard Worker self.assertTrue(reflectors.mT.is_contiguous()) 5021*da0073e9SAndroid Build Coastguard Worker reflectors = reflectors.contiguous() 5022*da0073e9SAndroid Build Coastguard Worker 5023*da0073e9SAndroid Build Coastguard Worker # Q is of size m x m 5024*da0073e9SAndroid Build Coastguard Worker Q, _ = torch.linalg.qr(A, mode='complete') 5025*da0073e9SAndroid Build Coastguard Worker C_right = make_tensor((*batch, m, n), dtype=dtype, device=device) 5026*da0073e9SAndroid Build Coastguard Worker C_left = make_tensor((*batch, n, m), dtype=dtype, device=device) 5027*da0073e9SAndroid Build Coastguard Worker 
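            # Note on semantics (added commentary, not upstream): torch.ormqr applies the
            # implicit Q encoded by (reflectors, tau) from torch.geqrf to the other operand
            # without materializing Q. The four checks below compare against an explicit Q
            # from torch.linalg.qr(A, mode='complete'): the left/transpose flags select
            # Q @ C, C @ Q, Q^H @ C and C @ Q^H respectively (transpose=True means the
            # conjugate transpose for complex dtypes, which is why Q.mH is the reference).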
5028*da0073e9SAndroid Build Coastguard Worker expected = Q @ C_right 5029*da0073e9SAndroid Build Coastguard Worker actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=False) 5030*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 5031*da0073e9SAndroid Build Coastguard Worker 5032*da0073e9SAndroid Build Coastguard Worker expected = C_left @ Q 5033*da0073e9SAndroid Build Coastguard Worker actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=False) 5034*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 5035*da0073e9SAndroid Build Coastguard Worker 5036*da0073e9SAndroid Build Coastguard Worker expected = Q.mH @ C_right 5037*da0073e9SAndroid Build Coastguard Worker actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=True) 5038*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 5039*da0073e9SAndroid Build Coastguard Worker 5040*da0073e9SAndroid Build Coastguard Worker expected = C_left @ Q.mH 5041*da0073e9SAndroid Build Coastguard Worker actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=True) 5042*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 5043*da0073e9SAndroid Build Coastguard Worker 5044*da0073e9SAndroid Build Coastguard Worker # if tau is all zeros then the implicit matrix Q is the identity matrix 5045*da0073e9SAndroid Build Coastguard Worker # so the actual result should be C_right in this case 5046*da0073e9SAndroid Build Coastguard Worker zero_tau = torch.zeros_like(tau) 5047*da0073e9SAndroid Build Coastguard Worker actual = torch.ormqr(reflectors, zero_tau, C_right, left=True, transpose=False) 5048*da0073e9SAndroid Build Coastguard Worker self.assertEqual(C_right, actual) 5049*da0073e9SAndroid Build Coastguard Worker 5050*da0073e9SAndroid Build Coastguard Worker batches = [(), (0, ), (2, ), (2, 1)] 5051*da0073e9SAndroid Build Coastguard Worker ns = [5, 2, 0] 5052*da0073e9SAndroid Build Coastguard Worker for batch, (m, n), fortran_contiguous in product(batches, product(ns, ns), [True, False]): 5053*da0073e9SAndroid Build Coastguard Worker run_test(batch, m, n, fortran_contiguous) 5054*da0073e9SAndroid Build Coastguard Worker 5055*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 5056*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 5057*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 5058*da0073e9SAndroid Build Coastguard Worker def test_ormqr_errors_and_warnings(self, device, dtype): 5059*da0073e9SAndroid Build Coastguard Worker test_cases = [ 5060*da0073e9SAndroid Build Coastguard Worker # input1 size, input2 size, input3 size, error regex 5061*da0073e9SAndroid Build Coastguard Worker ((10,), (2,), (2,), r"input must have at least 2 dimensions"), 5062*da0073e9SAndroid Build Coastguard Worker ((2, 2), (2,), (2,), r"other must have at least 2 dimensions"), 5063*da0073e9SAndroid Build Coastguard Worker ((10, 6), (20,), (10, 6), r"other.shape\[-2\] must be greater than or equal to tau.shape\[-1\]"), 5064*da0073e9SAndroid Build Coastguard Worker ((6, 6), (5,), (5, 5), r"other.shape\[-2\] must be equal to input.shape\[-2\]"), 5065*da0073e9SAndroid Build Coastguard Worker ((1, 2, 2), (2, 2), (1, 2, 2), r"batch dimensions of tau to be equal to input.shape\[:-2\]"), 5066*da0073e9SAndroid Build Coastguard Worker ((1, 2, 2), (1, 2), (2, 2, 2), r"batch dimensions of other to be equal to input.shape\[:-2\]"), 5067*da0073e9SAndroid Build Coastguard Worker ] 
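        # Summary of the shape validation exercised by the cases above and the loop below
        # (added commentary, not upstream): input and other must be at least 2-D,
        # other.shape[-2] must equal input.shape[-2] and be >= tau.shape[-1], and the batch
        # dimensions of both tau and other must match input.shape[:-2].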
5068*da0073e9SAndroid Build Coastguard Worker for a_size, tau_size, c_size, error_regex in test_cases: 5069*da0073e9SAndroid Build Coastguard Worker a = make_tensor(a_size, dtype=dtype, device=device) 5070*da0073e9SAndroid Build Coastguard Worker tau = make_tensor(tau_size, dtype=dtype, device=device) 5071*da0073e9SAndroid Build Coastguard Worker c = make_tensor(c_size, dtype=dtype, device=device) 5072*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error_regex): 5073*da0073e9SAndroid Build Coastguard Worker torch.ormqr(a, tau, c) 5074*da0073e9SAndroid Build Coastguard Worker 5075*da0073e9SAndroid Build Coastguard Worker def test_blas_empty(self, device): 5076*da0073e9SAndroid Build Coastguard Worker def fn(torchfn, *args, test_out=False, **kwargs): 5077*da0073e9SAndroid Build Coastguard Worker def call_torch_fn(*args, **kwargs): 5078*da0073e9SAndroid Build Coastguard Worker return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape 5079*da0073e9SAndroid Build Coastguard Worker for shape in args), **kwargs) 5080*da0073e9SAndroid Build Coastguard Worker result = call_torch_fn(*args, **kwargs) 5081*da0073e9SAndroid Build Coastguard Worker if not test_out: 5082*da0073e9SAndroid Build Coastguard Worker return result 5083*da0073e9SAndroid Build Coastguard Worker else: 5084*da0073e9SAndroid Build Coastguard Worker out = torch.full_like(result, math.nan) 5085*da0073e9SAndroid Build Coastguard Worker out1 = call_torch_fn(*args, **kwargs, out=out) 5086*da0073e9SAndroid Build Coastguard Worker return out 5087*da0073e9SAndroid Build Coastguard Worker 5088*da0073e9SAndroid Build Coastguard Worker # mm, addmm 5089*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) 5090*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) 5091*da0073e9SAndroid Build Coastguard Worker self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) 5092*da0073e9SAndroid Build Coastguard Worker self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) 5093*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))) 5094*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True)) 5095*da0073e9SAndroid Build Coastguard Worker 5096*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) 5097*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape) 5098*da0073e9SAndroid Build Coastguard Worker t = torch.randn((5, 6), device=device) 5099*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) 5100*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) 5101*da0073e9SAndroid Build Coastguard Worker 5102*da0073e9SAndroid Build Coastguard Worker # mv, addmv 5103*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) 5104*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) 5105*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) 5106*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), 
test_out=True)) 5107*da0073e9SAndroid Build Coastguard Worker 5108*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) 5109*da0073e9SAndroid Build Coastguard Worker t = torch.randn((3,), device=device) 5110*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) 5111*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) 5112*da0073e9SAndroid Build Coastguard Worker 5113*da0073e9SAndroid Build Coastguard Worker # bmm, baddbmm 5114*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) 5115*da0073e9SAndroid Build Coastguard Worker self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) 5116*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) 5117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))) 5118*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True)) 5119*da0073e9SAndroid Build Coastguard Worker 5120*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape) 5121*da0073e9SAndroid Build Coastguard Worker self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape) 5122*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape) 5123*da0073e9SAndroid Build Coastguard Worker self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape) 5124*da0073e9SAndroid Build Coastguard Worker c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) 5125*da0073e9SAndroid Build Coastguard Worker self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)) # Issue #33467 5126*da0073e9SAndroid Build Coastguard Worker self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)) # Issue #33467 5127*da0073e9SAndroid Build Coastguard Worker 5128*da0073e9SAndroid Build Coastguard Worker # addbmm 5129*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) 5130*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) 5131*da0073e9SAndroid Build Coastguard Worker t = torch.randn((5, 6), device=device) 5132*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) 5133*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) 5134*da0073e9SAndroid Build Coastguard Worker 5135*da0073e9SAndroid Build Coastguard Worker # matmul 5136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,))) 5137*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True)) 5138*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) 5139*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) 5140*da0073e9SAndroid Build Coastguard Worker self.assertEqual((5, 0, 0), 
fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) 5141*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4))) 5142*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True)) 5143*da0073e9SAndroid Build Coastguard Worker 5144*da0073e9SAndroid Build Coastguard Worker # dot 5145*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,))) 5146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True)) 5147*da0073e9SAndroid Build Coastguard Worker 5148*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, 5149*da0073e9SAndroid Build Coastguard Worker torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) 5150*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_and_complex_types_and( 5151*da0073e9SAndroid Build Coastguard Worker torch.half, 5152*da0073e9SAndroid Build Coastguard Worker *[torch.bfloat16] if SM53OrLater else [] 5153*da0073e9SAndroid Build Coastguard Worker )) 5154*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.bfloat16)) 5155*da0073e9SAndroid Build Coastguard Worker def test_corner_cases_of_cublasltmatmul(self, device, dtype): 5156*da0073e9SAndroid Build Coastguard Worker # common case 5157*da0073e9SAndroid Build Coastguard Worker M = torch.randn(128, device=device).to(dtype) 5158*da0073e9SAndroid Build Coastguard Worker m1 = torch.randn(2048, 2400, device=device).to(dtype) 5159*da0073e9SAndroid Build Coastguard Worker m2 = torch.randn(128, 2400, device=device).to(dtype) 5160*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.linear(m1, m2, M) 5161*da0073e9SAndroid Build Coastguard Worker # Ntrans_B has ld >> rows 5162*da0073e9SAndroid Build Coastguard Worker m1 = torch.rand([128, 2400]).to(dtype).to(device).t() 5163*da0073e9SAndroid Build Coastguard Worker m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340] 5164*da0073e9SAndroid Build Coastguard Worker M = torch.rand([128]).to(dtype).to(device) 5165*da0073e9SAndroid Build Coastguard Worker torch.addmm(M, m2.t(), m1) 5166*da0073e9SAndroid Build Coastguard Worker # trans_A has ld >> rows 5167*da0073e9SAndroid Build Coastguard Worker m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t() 5168*da0073e9SAndroid Build Coastguard Worker m2 = torch.randn(2048, 2400, device=device).to(dtype) 5169*da0073e9SAndroid Build Coastguard Worker M = torch.rand([128]).to(dtype).to(device) 5170*da0073e9SAndroid Build Coastguard Worker torch.addmm(M, m2, m1) 5171*da0073e9SAndroid Build Coastguard Worker # large tensor dim > 65535 5172*da0073e9SAndroid Build Coastguard Worker M = torch.randn(16, device=device).to(dtype) 5173*da0073e9SAndroid Build Coastguard Worker m1 = torch.randn(32, 131071 , device=device).to(dtype) 5174*da0073e9SAndroid Build Coastguard Worker m2 = torch.randn(16, 131071, device=device).to(dtype) 5175*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.linear(m1, m2, M) 5176*da0073e9SAndroid Build Coastguard Worker 5177*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 5178*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNotRocm 5179*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.bfloat16, torch.half)) 5180*da0073e9SAndroid Build 
Coastguard Worker def test_hipblaslt_corner_cases_rocm(self, device, dtype): 5181*da0073e9SAndroid Build Coastguard Worker if dtype == torch.double: 5182*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("hipblasLt doesn't support doubles yet") 5183*da0073e9SAndroid Build Coastguard Worker 5184*da0073e9SAndroid Build Coastguard Worker # enable hipblaslt path via env variable. 5185*da0073e9SAndroid Build Coastguard Worker import os 5186*da0073e9SAndroid Build Coastguard Worker DISABLE_ADDMM_HIP_LT = "DISABLE_ADDMM_HIP_LT" 5187*da0073e9SAndroid Build Coastguard Worker prev_val = os.getenv(DISABLE_ADDMM_HIP_LT) 5188*da0073e9SAndroid Build Coastguard Worker try: 5189*da0073e9SAndroid Build Coastguard Worker os.environ[DISABLE_ADDMM_HIP_LT] = "0" 5190*da0073e9SAndroid Build Coastguard Worker # common case 5191*da0073e9SAndroid Build Coastguard Worker M = torch.randn(128, device=device, dtype=dtype) 5192*da0073e9SAndroid Build Coastguard Worker m1 = torch.randn(2048, 2400, device=device, dtype=dtype) 5193*da0073e9SAndroid Build Coastguard Worker m2 = torch.randn(128, 2400, device=device, dtype=dtype) 5194*da0073e9SAndroid Build Coastguard Worker out1 = torch.nn.functional.linear(m1, m2, M) 5195*da0073e9SAndroid Build Coastguard Worker M_cpu = M.to('cpu') 5196*da0073e9SAndroid Build Coastguard Worker m1_cpu = m1.to('cpu') 5197*da0073e9SAndroid Build Coastguard Worker m2_cpu = m2.to('cpu') 5198*da0073e9SAndroid Build Coastguard Worker out1_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, M_cpu) 5199*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(out1_cpu, out1.cpu(), rtol=1e-2, atol=1e-2)) 5200*da0073e9SAndroid Build Coastguard Worker 5201*da0073e9SAndroid Build Coastguard Worker # common case without bias 5202*da0073e9SAndroid Build Coastguard Worker m1 = torch.randn(2048, 2400, device=device, dtype=dtype) 5203*da0073e9SAndroid Build Coastguard Worker m2 = torch.randn(128, 2400, device=device, dtype=dtype) 5204*da0073e9SAndroid Build Coastguard Worker out2 = torch.nn.functional.linear(m1, m2, bias=None) 5205*da0073e9SAndroid Build Coastguard Worker m1_cpu = m1.to('cpu') 5206*da0073e9SAndroid Build Coastguard Worker m2_cpu = m2.to('cpu') 5207*da0073e9SAndroid Build Coastguard Worker out2_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, bias=None) 5208*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(out2_cpu, out2.cpu(), rtol=1e-2, atol=1e-2)) 5209*da0073e9SAndroid Build Coastguard Worker finally: 5210*da0073e9SAndroid Build Coastguard Worker if prev_val is None: 5211*da0073e9SAndroid Build Coastguard Worker del os.environ[DISABLE_ADDMM_HIP_LT] 5212*da0073e9SAndroid Build Coastguard Worker else: 5213*da0073e9SAndroid Build Coastguard Worker os.environ[DISABLE_ADDMM_HIP_LT] = prev_val 5214*da0073e9SAndroid Build Coastguard Worker 5215*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_and_complex_types_and( 5216*da0073e9SAndroid Build Coastguard Worker torch.half, 5217*da0073e9SAndroid Build Coastguard Worker *[torch.bfloat16] if SM53OrLater else [] 5218*da0073e9SAndroid Build Coastguard Worker )) 5219*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.half)) 5220*da0073e9SAndroid Build Coastguard Worker def test_blas_alpha_beta_empty(self, device, dtype): 5221*da0073e9SAndroid Build Coastguard Worker # This test is disabled on CUDA 9 due to: 5222*da0073e9SAndroid Build Coastguard Worker # See: https://github.com/pytorch/pytorch/issues/31006 5223*da0073e9SAndroid Build 
Coastguard Worker if dtype is torch.bfloat16 and self.device_type == 'xla': 5224*da0073e9SAndroid Build Coastguard Worker # TODO (@zasdfgbnm): this causes the following error on test 5225*da0073e9SAndroid Build Coastguard Worker # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16: 5226*da0073e9SAndroid Build Coastguard Worker # 5227*da0073e9SAndroid Build Coastguard Worker # RuntimeError: _th_equal not supported on CPUType for BFloat16 5228*da0073e9SAndroid Build Coastguard Worker return 5229*da0073e9SAndroid Build Coastguard Worker # ensure beta is respected 5230*da0073e9SAndroid Build Coastguard Worker value = 11 5231*da0073e9SAndroid Build Coastguard Worker input = torch.full((2,), value, dtype=dtype, device=device) 5232*da0073e9SAndroid Build Coastguard Worker mat = torch.ones((2, 0), dtype=dtype, device=device) 5233*da0073e9SAndroid Build Coastguard Worker vec = torch.ones((0,), dtype=dtype, device=device) 5234*da0073e9SAndroid Build Coastguard Worker out = torch.empty((2,), dtype=dtype, device=device) 5235*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 5236*da0073e9SAndroid Build Coastguard Worker alpha = 6 + 7j 5237*da0073e9SAndroid Build Coastguard Worker beta = 3 + 4j 5238*da0073e9SAndroid Build Coastguard Worker else: 5239*da0073e9SAndroid Build Coastguard Worker alpha = 6 5240*da0073e9SAndroid Build Coastguard Worker beta = 3 5241*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device), 5242*da0073e9SAndroid Build Coastguard Worker torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta)) 5243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device), 5244*da0073e9SAndroid Build Coastguard Worker torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out)) 5245*da0073e9SAndroid Build Coastguard Worker 5246*da0073e9SAndroid Build Coastguard Worker # torch.addmm 5247*da0073e9SAndroid Build Coastguard Worker input = torch.full((2, 3), value, dtype=dtype, device=device) 5248*da0073e9SAndroid Build Coastguard Worker mat2 = torch.ones((0, 3), dtype=dtype, device=device) 5249*da0073e9SAndroid Build Coastguard Worker out = torch.empty((2, 3), dtype=dtype, device=device) 5250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), 5251*da0073e9SAndroid Build Coastguard Worker torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta)) 5252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), 5253*da0073e9SAndroid Build Coastguard Worker torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out)) 5254*da0073e9SAndroid Build Coastguard Worker 5255*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16)) 5256*da0073e9SAndroid Build Coastguard Worker def test_blas_nan_out(self, device, dtype): 5257*da0073e9SAndroid Build Coastguard Worker # These functions should work correctly with NaN filled outputs, 5258*da0073e9SAndroid Build Coastguard Worker # but need special handling, see [NOTE: cpu_zero] 5259*da0073e9SAndroid Build Coastguard Worker b = 3 5260*da0073e9SAndroid Build Coastguard Worker n = 5 5261*da0073e9SAndroid Build Coastguard Worker m = 7 5262*da0073e9SAndroid Build Coastguard Worker p = 11 5263*da0073e9SAndroid Build Coastguard Worker 5264*da0073e9SAndroid Build Coastguard Worker # torch.mv 
5265*da0073e9SAndroid Build Coastguard Worker nm = torch.randn((m, n), device=device).t() 5266*da0073e9SAndroid Build Coastguard Worker _m = torch.randn((), device=device).expand(m) 5267*da0073e9SAndroid Build Coastguard Worker _m_out = torch.full((m,), float('nan'), device=device) 5268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out)) 5269*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum()) 5270*da0073e9SAndroid Build Coastguard Worker 5271*da0073e9SAndroid Build Coastguard Worker # torch.mm 5272*da0073e9SAndroid Build Coastguard Worker mp = torch.randn((p, m), device=device).t() 5273*da0073e9SAndroid Build Coastguard Worker np_out = torch.full((n, p), float('nan'), device=device) 5274*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out)) 5275*da0073e9SAndroid Build Coastguard Worker 5276*da0073e9SAndroid Build Coastguard Worker # torch.bmm 5277*da0073e9SAndroid Build Coastguard Worker bnm = torch.randn((b, m, n), device=device).transpose(1, 2) 5278*da0073e9SAndroid Build Coastguard Worker bmp = torch.randn((b, p, m), device=device).transpose(1, 2) 5279*da0073e9SAndroid Build Coastguard Worker bnp_out = torch.full((b, n, p), float('nan'), device=device) 5280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out)) 5281*da0073e9SAndroid Build Coastguard Worker 5282*da0073e9SAndroid Build Coastguard Worker @onlyCPU # not supported by CUBLAS 5283*da0073e9SAndroid Build Coastguard Worker def test_blas_mv_large_input(self, device): 5284*da0073e9SAndroid Build Coastguard Worker # This would previously fail if the allocated output had NaNs, see: 5285*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/31663 and [NOTE: cpu_zero] 5286*da0073e9SAndroid Build Coastguard Worker n = 3000 5287*da0073e9SAndroid Build Coastguard Worker m = 200 5288*da0073e9SAndroid Build Coastguard Worker 5289*da0073e9SAndroid Build Coastguard Worker nm = torch.randn((m, n), device=device).t() 5290*da0073e9SAndroid Build Coastguard Worker _m = torch.randn((), device=device).expand(m) 5291*da0073e9SAndroid Build Coastguard Worker _m_out = torch.full((m,), 0., device=device) 5292*da0073e9SAndroid Build Coastguard Worker 5293*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out)) 5294*da0073e9SAndroid Build Coastguard Worker 5295*da0073e9SAndroid Build Coastguard Worker @onlyCPU 5296*da0073e9SAndroid Build Coastguard Worker def test_renorm_ps(self, device): 5297*da0073e9SAndroid Build Coastguard Worker # full reduction 5298*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 5299*da0073e9SAndroid Build Coastguard Worker xn = x.numpy() 5300*da0073e9SAndroid Build Coastguard Worker for p in [1, 2, 3, 4, inf]: 5301*da0073e9SAndroid Build Coastguard Worker res = x.renorm(p, 1, 1) 5302*da0073e9SAndroid Build Coastguard Worker expected = x / x.norm(p, 0, keepdim=True).clamp(min=1) 5303*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, msg=f"renorm failed for {p}-norm") 5304*da0073e9SAndroid Build Coastguard Worker 5305*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 5306*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 5307*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 5308*da0073e9SAndroid Build Coastguard Worker def test_householder_product(self, 
device, dtype): 5309*da0073e9SAndroid Build Coastguard Worker def generate_reflectors_and_tau(A): 5310*da0073e9SAndroid Build Coastguard Worker """ 5311*da0073e9SAndroid Build Coastguard Worker This function uses numpy.linalg.qr with mode "raw" to extract output of LAPACK's geqrf. 5312*da0073e9SAndroid Build Coastguard Worker There is torch.geqrf function but it doesn't work with complex-valued input. 5313*da0073e9SAndroid Build Coastguard Worker """ 5314*da0073e9SAndroid Build Coastguard Worker if A.numel() > 0: 5315*da0073e9SAndroid Build Coastguard Worker A_cpu = A.cpu() 5316*da0073e9SAndroid Build Coastguard Worker flattened_batch_shape = [-1, *A_cpu.shape[-2:]] 5317*da0073e9SAndroid Build Coastguard Worker reflectors = torch.empty_like(A_cpu).view(*flattened_batch_shape) 5318*da0073e9SAndroid Build Coastguard Worker tau_shape = [*A_cpu.shape[:-2], A_cpu.shape[-1]] 5319*da0073e9SAndroid Build Coastguard Worker tau = torch.empty(tau_shape, dtype=dtype).view(-1, A_cpu.shape[-1]) 5320*da0073e9SAndroid Build Coastguard Worker for A_i, reflectors_i, tau_i in zip(A_cpu.contiguous().view(*flattened_batch_shape), reflectors, tau): 5321*da0073e9SAndroid Build Coastguard Worker reflectors_tmp, tau_i[:] = map(torch.from_numpy, np.linalg.qr(A_i, mode='raw')) 5322*da0073e9SAndroid Build Coastguard Worker reflectors_i[:] = reflectors_tmp.T 5323*da0073e9SAndroid Build Coastguard Worker reflectors = reflectors.view(*A_cpu.shape) 5324*da0073e9SAndroid Build Coastguard Worker tau = tau.view(tau_shape) 5325*da0073e9SAndroid Build Coastguard Worker return reflectors.to(A.device), tau.to(A.device) 5326*da0073e9SAndroid Build Coastguard Worker 5327*da0073e9SAndroid Build Coastguard Worker reflectors = torch.empty_like(A) 5328*da0073e9SAndroid Build Coastguard Worker tau = torch.empty(*A.shape[:-2], A.shape[-1], dtype=dtype, device=device) 5329*da0073e9SAndroid Build Coastguard Worker return reflectors, tau 5330*da0073e9SAndroid Build Coastguard Worker 5331*da0073e9SAndroid Build Coastguard Worker def run_test(shape): 5332*da0073e9SAndroid Build Coastguard Worker A = torch.randn(*shape, dtype=dtype, device=device) 5333*da0073e9SAndroid Build Coastguard Worker reflectors, tau = generate_reflectors_and_tau(A) 5334*da0073e9SAndroid Build Coastguard Worker expected, _ = torch.linalg.qr(A) 5335*da0073e9SAndroid Build Coastguard Worker actual = torch.linalg.householder_product(reflectors, tau) 5336*da0073e9SAndroid Build Coastguard Worker # torch.linalg.qr does not work correctly for zero batch dimension tensors 5337*da0073e9SAndroid Build Coastguard Worker # see https://github.com/pytorch/pytorch/issues/50576 5338*da0073e9SAndroid Build Coastguard Worker if (A.numel() > 0): 5339*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 5340*da0073e9SAndroid Build Coastguard Worker else: 5341*da0073e9SAndroid Build Coastguard Worker self.assertTrue(actual.shape == shape) 5342*da0073e9SAndroid Build Coastguard Worker 5343*da0073e9SAndroid Build Coastguard Worker # if tau is empty and A is not the result should be a matrix with ones on the diagonal 5344*da0073e9SAndroid Build Coastguard Worker if (A.numel() > 0): 5345*da0073e9SAndroid Build Coastguard Worker tau_empty = torch.empty(*shape[:-2], 0, dtype=dtype, device=device) 5346*da0073e9SAndroid Build Coastguard Worker identity_mat = torch.zeros_like(reflectors) 5347*da0073e9SAndroid Build Coastguard Worker identity_mat.diagonal(dim1=-1, dim2=-2)[:] = 1 5348*da0073e9SAndroid Build Coastguard Worker actual = 
torch.linalg.householder_product(reflectors, tau_empty) 5349*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, identity_mat) 5350*da0073e9SAndroid Build Coastguard Worker 5351*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(A) 5352*da0073e9SAndroid Build Coastguard Worker ans = torch.linalg.householder_product(reflectors, tau, out=out) 5353*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ans, out) 5354*da0073e9SAndroid Build Coastguard Worker if (A.numel() > 0): 5355*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, out) 5356*da0073e9SAndroid Build Coastguard Worker 5357*da0073e9SAndroid Build Coastguard Worker shapes = [(0, 0), (5, 0), # Empty matrix 5358*da0073e9SAndroid Build Coastguard Worker (5, 5), (5, 3), # Single matrix 5359*da0073e9SAndroid Build Coastguard Worker (0, 0, 0), (0, 5, 5), (0, 5, 3), # Zero batch dimension tensors 5360*da0073e9SAndroid Build Coastguard Worker (2, 5, 5), (2, 5, 3), # 3-dim tensors 5361*da0073e9SAndroid Build Coastguard Worker (2, 1, 5, 5), (2, 1, 5, 3)] # 4-dim tensors 5362*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 5363*da0073e9SAndroid Build Coastguard Worker run_test(shape) 5364*da0073e9SAndroid Build Coastguard Worker 5365*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 5366*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoCusolver 5367*da0073e9SAndroid Build Coastguard Worker def test_householder_product_errors_and_warnings(self, device): 5368*da0073e9SAndroid Build Coastguard Worker test_cases = [ 5369*da0073e9SAndroid Build Coastguard Worker # input1 size, input2 size, error regex 5370*da0073e9SAndroid Build Coastguard Worker ((10,), (2,), r"input must have at least 2 dimensions"), 5371*da0073e9SAndroid Build Coastguard Worker ((10, 6), (20,), r"input.shape\[-1\] must be greater than or equal to tau.shape\[-1\]"), 5372*da0073e9SAndroid Build Coastguard Worker ((6, 10), (5,), r"input.shape\[-2\] must be greater than or equal to input.shape\[-1\]"), 5373*da0073e9SAndroid Build Coastguard Worker ] 5374*da0073e9SAndroid Build Coastguard Worker for a_size, tau_size, error_regex in test_cases: 5375*da0073e9SAndroid Build Coastguard Worker a = torch.rand(*a_size, device=device) 5376*da0073e9SAndroid Build Coastguard Worker tau = torch.rand(*tau_size, device=device) 5377*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error_regex): 5378*da0073e9SAndroid Build Coastguard Worker torch.linalg.householder_product(a, tau) 5379*da0073e9SAndroid Build Coastguard Worker 5380*da0073e9SAndroid Build Coastguard Worker # if out tensor with wrong shape is passed a warning is given 5381*da0073e9SAndroid Build Coastguard Worker reflectors = torch.randn(3, 3, device=device) 5382*da0073e9SAndroid Build Coastguard Worker tau = torch.randn(3, device=device) 5383*da0073e9SAndroid Build Coastguard Worker out = torch.empty(2, 3, device=device) 5384*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 5385*da0073e9SAndroid Build Coastguard Worker # Trigger warning 5386*da0073e9SAndroid Build Coastguard Worker torch.linalg.householder_product(reflectors, tau, out=out) 5387*da0073e9SAndroid Build Coastguard Worker # Check warning occurs 5388*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 5389*da0073e9SAndroid Build Coastguard Worker self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) 5390*da0073e9SAndroid Build Coastguard Worker 5391*da0073e9SAndroid Build 
Coastguard Worker # dtypes should be safely castable 5392*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(reflectors).to(torch.int) 5393*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): 5394*da0073e9SAndroid Build Coastguard Worker torch.linalg.householder_product(reflectors, tau, out=out) 5395*da0073e9SAndroid Build Coastguard Worker 5396*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "tau dtype Int does not match input dtype"): 5397*da0073e9SAndroid Build Coastguard Worker torch.linalg.householder_product(reflectors, tau.to(torch.int)) 5398*da0073e9SAndroid Build Coastguard Worker 5399*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 5400*da0073e9SAndroid Build Coastguard Worker # device of out and input should match 5401*da0073e9SAndroid Build Coastguard Worker wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 5402*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(reflectors).to(wrong_device) 5403*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): 5404*da0073e9SAndroid Build Coastguard Worker torch.linalg.householder_product(reflectors, tau, out=out) 5405*da0073e9SAndroid Build Coastguard Worker 5406*da0073e9SAndroid Build Coastguard Worker # device of tau and input should match 5407*da0073e9SAndroid Build Coastguard Worker wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' 5408*da0073e9SAndroid Build Coastguard Worker tau = tau.to(wrong_device) 5409*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): 5410*da0073e9SAndroid Build Coastguard Worker torch.linalg.householder_product(reflectors, tau) 5411*da0073e9SAndroid Build Coastguard Worker 5412*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2}) 5413*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagmaAndNoCusolver 5414*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Runtime error with torch._C._linalg.linalg_lu_factor") 5415*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 5416*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 5417*da0073e9SAndroid Build Coastguard Worker def test_linalg_lu_family(self, device, dtype): 5418*da0073e9SAndroid Build Coastguard Worker # Tests torch.lu 5419*da0073e9SAndroid Build Coastguard Worker # torch.linalg.lu_factor 5420*da0073e9SAndroid Build Coastguard Worker # torch.linalg.lu_factor_ex 5421*da0073e9SAndroid Build Coastguard Worker # torch.lu_unpack 5422*da0073e9SAndroid Build Coastguard Worker # torch.linalg.lu_solve 5423*da0073e9SAndroid Build Coastguard Worker # torch.linalg.solve 5424*da0073e9SAndroid Build Coastguard Worker make_arg_full = partial(make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype) 5425*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, device=device, dtype=dtype) 5426*da0073e9SAndroid Build Coastguard Worker 5427*da0073e9SAndroid Build Coastguard Worker def run_test(A, pivot, singular, fn): 5428*da0073e9SAndroid Build Coastguard Worker k = min(A.shape[-2:]) 5429*da0073e9SAndroid Build Coastguard Worker batch = A.shape[:-2] 5430*da0073e9SAndroid Build Coastguard Worker check_errors = (fn == torch.linalg.lu_factor) 5431*da0073e9SAndroid Build Coastguard Worker if singular and 
check_errors: 5432*da0073e9SAndroid Build Coastguard Worker # It may or may not throw as the LU decomposition without pivoting 5433*da0073e9SAndroid Build Coastguard Worker # may still succeed for singular matrices 5434*da0073e9SAndroid Build Coastguard Worker try: 5435*da0073e9SAndroid Build Coastguard Worker LU, pivots = fn(A, pivot=pivot) 5436*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 5437*da0073e9SAndroid Build Coastguard Worker return 5438*da0073e9SAndroid Build Coastguard Worker else: 5439*da0073e9SAndroid Build Coastguard Worker LU, pivots = fn(A, pivot=pivot)[:2] 5440*da0073e9SAndroid Build Coastguard Worker 5441*da0073e9SAndroid Build Coastguard Worker self.assertEqual(LU.size(), A.shape) 5442*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pivots.size(), batch + (k,)) 5443*da0073e9SAndroid Build Coastguard Worker 5444*da0073e9SAndroid Build Coastguard Worker if not pivot: 5445*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pivots, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(batch + (k, ))) 5446*da0073e9SAndroid Build Coastguard Worker 5447*da0073e9SAndroid Build Coastguard Worker P, L, U = torch.lu_unpack(LU, pivots, unpack_pivots=pivot) 5448*da0073e9SAndroid Build Coastguard Worker 5449*da0073e9SAndroid Build Coastguard Worker self.assertEqual(P @ L @ U if pivot else L @ U, A) 5450*da0073e9SAndroid Build Coastguard Worker 5451*da0073e9SAndroid Build Coastguard Worker PLU = torch.linalg.lu(A, pivot=pivot) 5452*da0073e9SAndroid Build Coastguard Worker self.assertEqual(P, PLU.P) 5453*da0073e9SAndroid Build Coastguard Worker self.assertEqual(L, PLU.L) 5454*da0073e9SAndroid Build Coastguard Worker self.assertEqual(U, PLU.U) 5455*da0073e9SAndroid Build Coastguard Worker 5456*da0073e9SAndroid Build Coastguard Worker if not singular and A.size(-2) == A.size(-1): 5457*da0073e9SAndroid Build Coastguard Worker nrhs = ((), (1,), (3,)) 5458*da0073e9SAndroid Build Coastguard Worker for left, rhs in product((True, False), nrhs): 5459*da0073e9SAndroid Build Coastguard Worker # Vector case when left = False is not allowed 5460*da0073e9SAndroid Build Coastguard Worker if not left and rhs == (): 5461*da0073e9SAndroid Build Coastguard Worker continue 5462*da0073e9SAndroid Build Coastguard Worker if left: 5463*da0073e9SAndroid Build Coastguard Worker shape_B = A.shape[:-1] + rhs 5464*da0073e9SAndroid Build Coastguard Worker else: 5465*da0073e9SAndroid Build Coastguard Worker shape_B = A.shape[:-2] + rhs + A.shape[-1:] 5466*da0073e9SAndroid Build Coastguard Worker B = make_arg(shape_B) 5467*da0073e9SAndroid Build Coastguard Worker 5468*da0073e9SAndroid Build Coastguard Worker # Test linalg.lu_solve. 
It does not support vectors as rhs 5469*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/pull/74045#issuecomment-1112304913 5470*da0073e9SAndroid Build Coastguard Worker if rhs != (): 5471*da0073e9SAndroid Build Coastguard Worker for adjoint in (True, False): 5472*da0073e9SAndroid Build Coastguard Worker X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint) 5473*da0073e9SAndroid Build Coastguard Worker A_adj = A.mH if adjoint else A 5474*da0073e9SAndroid Build Coastguard Worker if left: 5475*da0073e9SAndroid Build Coastguard Worker self.assertEqual(B, A_adj @ X) 5476*da0073e9SAndroid Build Coastguard Worker else: 5477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(B, X @ A_adj) 5478*da0073e9SAndroid Build Coastguard Worker 5479*da0073e9SAndroid Build Coastguard Worker # Test linalg.solve 5480*da0073e9SAndroid Build Coastguard Worker X = torch.linalg.solve(A, B, left=left) 5481*da0073e9SAndroid Build Coastguard Worker X_ = X.unsqueeze(-1) if rhs == () else X 5482*da0073e9SAndroid Build Coastguard Worker B_ = B.unsqueeze(-1) if rhs == () else B 5483*da0073e9SAndroid Build Coastguard Worker if left: 5484*da0073e9SAndroid Build Coastguard Worker self.assertEqual(B_, A @ X_) 5485*da0073e9SAndroid Build Coastguard Worker else: 5486*da0073e9SAndroid Build Coastguard Worker self.assertEqual(B_, X_ @ A) 5487*da0073e9SAndroid Build Coastguard Worker 5488*da0073e9SAndroid Build Coastguard Worker 5489*da0073e9SAndroid Build Coastguard Worker sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0)) 5490*da0073e9SAndroid Build Coastguard Worker batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5)) 5491*da0073e9SAndroid Build Coastguard Worker # Non pivoting just implemented for CUDA 5492*da0073e9SAndroid Build Coastguard Worker pivots = (True, False) if self.device_type == "cuda" else (True,) 5493*da0073e9SAndroid Build Coastguard Worker fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex) 5494*da0073e9SAndroid Build Coastguard Worker for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns): 5495*da0073e9SAndroid Build Coastguard Worker shape = batch + ms 5496*da0073e9SAndroid Build Coastguard Worker A = make_arg(shape) if singular else make_arg_full(*shape) 5497*da0073e9SAndroid Build Coastguard Worker # Just do one of them on singular matrices 5498*da0073e9SAndroid Build Coastguard Worker if A.numel() == 0 and not singular: 5499*da0073e9SAndroid Build Coastguard Worker continue 5500*da0073e9SAndroid Build Coastguard Worker run_test(A, pivot, singular, fn) 5501*da0073e9SAndroid Build Coastguard Worker 5502*da0073e9SAndroid Build Coastguard Worker # Reproducer of a magma bug, 5503*da0073e9SAndroid Build Coastguard Worker # see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on 5504*da0073e9SAndroid Build Coastguard Worker # This is also a bug in cuSOLVER < 11.3 5505*da0073e9SAndroid Build Coastguard Worker if (dtype == torch.double 5506*da0073e9SAndroid Build Coastguard Worker and singular): 5507*da0073e9SAndroid Build Coastguard Worker A = torch.ones(batch + ms, dtype=dtype, device=device) 5508*da0073e9SAndroid Build Coastguard Worker run_test(A, pivot, singular, fn) 5509*da0073e9SAndroid Build Coastguard Worker 5510*da0073e9SAndroid Build Coastguard Worker # Info should be positive for rank deficient matrices 5511*da0073e9SAndroid Build Coastguard Worker A = torch.ones(5, 3, 3, device=device) 5512*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all()) 5513*da0073e9SAndroid Build Coastguard Worker 5514*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cpu': 5515*da0073e9SAndroid Build Coastguard Worker # Error checking, no pivoting variant on CPU 5516*da0073e9SAndroid Build Coastguard Worker fns = [torch.lu, torch.linalg.lu_factor, torch.linalg.lu_factor_ex, torch.linalg.lu] 5517*da0073e9SAndroid Build Coastguard Worker for f in fns: 5518*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'): 5519*da0073e9SAndroid Build Coastguard Worker f(torch.empty(1, 2, 2), pivot=False) 5520*da0073e9SAndroid Build Coastguard Worker 5521*da0073e9SAndroid Build Coastguard Worker 5522*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2}) 5523*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNoMagmaAndNoCusolver 5524*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoLapack 5525*da0073e9SAndroid Build Coastguard Worker @setLinalgBackendsToDefaultFinally 5526*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 5527*da0073e9SAndroid Build Coastguard Worker def test_linalg_lu_solve(self, device, dtype): 5528*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, dtype=dtype, device=device) 5529*da0073e9SAndroid Build Coastguard Worker 5530*da0073e9SAndroid Build Coastguard Worker backends = ["default"] 5531*da0073e9SAndroid Build Coastguard Worker 5532*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type == 'cuda': 5533*da0073e9SAndroid Build Coastguard Worker if torch.cuda.has_magma: 5534*da0073e9SAndroid Build Coastguard Worker backends.append("magma") 5535*da0073e9SAndroid Build Coastguard Worker if has_cusolver(): 5536*da0073e9SAndroid Build Coastguard Worker backends.append("cusolver") 5537*da0073e9SAndroid Build Coastguard Worker 5538*da0073e9SAndroid Build Coastguard Worker def gen_matrices(): 5539*da0073e9SAndroid Build Coastguard Worker rhs = 3 5540*da0073e9SAndroid Build Coastguard Worker ns = (5, 2, 0) 5541*da0073e9SAndroid Build Coastguard Worker batches = ((), (0,), (1,), (2,), (2, 1), (0, 2)) 5542*da0073e9SAndroid Build Coastguard Worker for batch, n in product(batches, ns): 5543*da0073e9SAndroid Build Coastguard Worker yield make_arg(batch + (n, n)), make_arg(batch + (n, rhs)) 5544*da0073e9SAndroid Build Coastguard Worker # Shapes to exercise all the paths 5545*da0073e9SAndroid Build Coastguard Worker shapes = ((1, 64), (2, 128), (1025, 2)) 5546*da0073e9SAndroid Build Coastguard Worker for b, n in shapes: 5547*da0073e9SAndroid Build Coastguard Worker yield make_arg((b, n, n)), make_arg((b, n, rhs)) 5548*da0073e9SAndroid Build Coastguard Worker 5549*da0073e9SAndroid Build Coastguard Worker 5550*da0073e9SAndroid Build Coastguard Worker for A, B in gen_matrices(): 5551*da0073e9SAndroid Build Coastguard Worker LU, pivots = torch.linalg.lu_factor(A) 5552*da0073e9SAndroid Build Coastguard Worker for backend in backends: 5553*da0073e9SAndroid Build Coastguard Worker torch.backends.cuda.preferred_linalg_library(backend) 5554*da0073e9SAndroid Build Coastguard Worker 5555*da0073e9SAndroid Build Coastguard Worker for left, adjoint in product((True, False), repeat=2): 5556*da0073e9SAndroid Build Coastguard Worker B_left = B if left else B.mT 5557*da0073e9SAndroid Build Coastguard Worker X = torch.linalg.lu_solve(LU, pivots, B_left, left=left, 
adjoint=adjoint) 5558*da0073e9SAndroid Build Coastguard Worker A_adj = A.mH if adjoint else A 5559*da0073e9SAndroid Build Coastguard Worker if left: 5560*da0073e9SAndroid Build Coastguard Worker self.assertEqual(B_left, A_adj @ X) 5561*da0073e9SAndroid Build Coastguard Worker else: 5562*da0073e9SAndroid Build Coastguard Worker self.assertEqual(B_left, X @ A_adj) 5563*da0073e9SAndroid Build Coastguard Worker 5564*da0073e9SAndroid Build Coastguard Worker 5565*da0073e9SAndroid Build Coastguard Worker @onlyCPU 5566*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 5567*da0073e9SAndroid Build Coastguard Worker def test_linalg_lu_cpu_errors(self, device, dtype): 5568*da0073e9SAndroid Build Coastguard Worker # Square tests 5569*da0073e9SAndroid Build Coastguard Worker sample = torch.randn(3, 2, 2, device=device, dtype=dtype) 5570*da0073e9SAndroid Build Coastguard Worker B = torch.randn(3, 2, 2, device=device, dtype=dtype) 5571*da0073e9SAndroid Build Coastguard Worker LU, pivots = torch.linalg.lu_factor(sample) 5572*da0073e9SAndroid Build Coastguard Worker 5573*da0073e9SAndroid Build Coastguard Worker # This should run without issues 5574*da0073e9SAndroid Build Coastguard Worker torch.linalg.lu_solve(LU, pivots, B, adjoint=True) 5575*da0073e9SAndroid Build Coastguard Worker torch.lu_unpack(LU, pivots) 5576*da0073e9SAndroid Build Coastguard Worker 5577*da0073e9SAndroid Build Coastguard Worker pivots[0] = 0 5578*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"greater or equal to 1"): 5579*da0073e9SAndroid Build Coastguard Worker torch.linalg.lu_solve(LU, pivots, B, adjoint=True) 5580*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5581*da0073e9SAndroid Build Coastguard Worker torch.lu_unpack(LU, pivots) 5582*da0073e9SAndroid Build Coastguard Worker 5583*da0073e9SAndroid Build Coastguard Worker pivots[0] = 3 5584*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"smaller or equal to LU.size\(-2\)"): 5585*da0073e9SAndroid Build Coastguard Worker torch.linalg.lu_solve(LU, pivots, B, adjoint=True) 5586*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5587*da0073e9SAndroid Build Coastguard Worker torch.lu_unpack(LU, pivots) 5588*da0073e9SAndroid Build Coastguard Worker 5589*da0073e9SAndroid Build Coastguard Worker # Rectangular tests 5590*da0073e9SAndroid Build Coastguard Worker sample = torch.randn(3, 4, 2, device=device, dtype=dtype) 5591*da0073e9SAndroid Build Coastguard Worker B = torch.randn(3, 4, 2, device=device, dtype=dtype) 5592*da0073e9SAndroid Build Coastguard Worker LU, pivots = torch.linalg.lu_factor(sample) 5593*da0073e9SAndroid Build Coastguard Worker 5594*da0073e9SAndroid Build Coastguard Worker # This should run without issues 5595*da0073e9SAndroid Build Coastguard Worker torch.lu_unpack(LU, pivots) 5596*da0073e9SAndroid Build Coastguard Worker 5597*da0073e9SAndroid Build Coastguard Worker pivots[0] = 0 5598*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5599*da0073e9SAndroid Build Coastguard Worker torch.lu_unpack(LU, pivots) 5600*da0073e9SAndroid Build Coastguard Worker 5601*da0073e9SAndroid Build Coastguard Worker pivots[0] = 5 5602*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): 5603*da0073e9SAndroid Build 
            torch.lu_unpack(LU, pivots)


        # Rectangular tests
        sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
        B = torch.randn(2, 3, 5, device=device, dtype=dtype)
        LU, pivots = torch.linalg.lu_factor(sample)

        # This should run without issues
        torch.lu_unpack(LU, pivots)

        pivots[0] = 0
        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
            torch.lu_unpack(LU, pivots)

        pivots[0] = 4
        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
            torch.lu_unpack(LU, pivots)


    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @dtypes(torch.double)
    def test_lu_unpack_check_input(self, device, dtype):
        x = torch.rand(5, 5, 5, device=device, dtype=dtype)
        lu_data, lu_pivots = torch.linalg.lu_factor(x)

        with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
            torch.lu_unpack(lu_data, lu_pivots.long())

        # check that when the unpack flags are unset, empty tensors are returned
        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False)
        self.assertTrue(l.numel() == 0 and u.numel() == 0)
        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False)
        self.assertTrue(p.numel() == 0)
        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False, unpack_pivots=False)
        self.assertTrue(p.numel() == 0 and l.numel() == 0 and u.numel() == 0)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_lobpcg_basic(self, device, dtype):
        self._test_lobpcg_method(device, dtype, 'basic')

    @skipCUDAIfNoCusolver
    @skipCPUIfNoLapack
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_lobpcg_basic(self, device, dtype):
        self._test_lobpcg_method(device, dtype, 'basic')

    @skipCUDAIfNoCusolver
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_lobpcg_ortho(self, device, dtype):
        if torch.version.hip:
            torch.backends.cuda.preferred_linalg_library('magma')
        self._test_lobpcg_method(device, dtype, 'ortho')
        if torch.version.hip:
            torch.backends.cuda.preferred_linalg_library('default')

    def _test_lobpcg_method(self, device, dtype, method):
        from torch.testing._internal.common_utils import random_symmetric_pd_matrix, random_sparse_pd_matrix
        from torch._linalg_utils import matmul, qform
        from torch._lobpcg import lobpcg

        def test_tracker(worker):
            k = worker.iparams['k']
            nc = worker.ivars['converged_count']
            if k <= nc:
                tol = worker.fparams['tol']
                rerr = worker.tvars['rerr']
                X = worker.X
                E = worker.E
                B = worker.B
                A = worker.A
                dtype = X.dtype
                device = X.device

                # Check convergence
                self.assertLessEqual(rerr[:k].max(), tol)

                # Check B-orthogonality
                I = torch.eye(k, k, dtype=dtype, device=device)
                self.assertEqual(qform(B, X[:, :k]), I)

                # Check block equation
                self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2, rtol=0)

        orig_lobpcg = lobpcg

        def lobpcg(*args, **kwargs):
            kwargs['tracker'] = test_tracker
            kwargs['niter'] = 1000
            kwargs['method'] = method
            kwargs['tol'] = 1e-8
            return orig_lobpcg(*args, **kwargs)
        prec = 5e-4

        # check dense input
        mm = torch.matmul
        for batches in [(), (2,), (2, 3)]:
            for m, n, k in [
                    (9, 3, 1),
                    (9, 3, 2),
                    (9, 2, 2),
                    (100, 15, 5),
            ]:
                # skip tests that are known to fail with the basic
                # LOBPCG method due to calling cholesky on singular
                # input
                if method == 'basic' and (m, n, k) in [(9, 2, 2), (100, 15, 5)]:
                    continue
                A = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
                B = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)

                # classical eigenvalue problem, smallest eigenvalues
                E, V = lobpcg(A, k=k, n=n, largest=False)
                self.assertEqual(E.shape, batches + (k,))
                self.assertEqual(V.shape, batches + (m, k))
                self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
                e = torch.linalg.eigvalsh(A)
                e_smallest = e[..., :k]
                self.assertEqual(E, e_smallest)

                # classical eigenvalue problem, largest eigenvalues
                E, V = lobpcg(A, k=k, n=n, largest=True)
                e_largest, _ = torch.sort(e[..., -k:], descending=True)
                self.assertEqual(E, e_largest, atol=prec, rtol=0)
                self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)

                # generalized eigenvalue problem, smallest eigenvalues
                E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
                self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec, rtol=0)

                # generalized eigenvalue problem, largest eigenvalues
                E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
                self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
                                 atol=prec, rtol=0)

        # check sparse input
        for m, n, k, density in [
                (5, 1, 1, 0.8),
                (9, 3, 2, 0.5),
                (100, 1, 1, 0.1),
                (1000, 7, 3, 0.01),
        ]:
            # skip tests that are known to fail with the basic LOBPCG
            # method due to insufficient accuracy
            if method == 'basic' and (m, n, k, density) in [(1000, 7, 3, 0.01)]:
                continue
            A = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
            B = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
            A_eigenvalues = torch.arange(1, m + 1, dtype=dtype) / m
            e_smallest = A_eigenvalues[..., :k]
            e_largest, _ = torch.sort(A_eigenvalues[..., -k:], descending=True)

            # classical eigenvalue problem, smallest eigenvalues
            E, V = lobpcg(A, k=k, n=n, largest=False)
            self.assertEqual(E, e_smallest)
            self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)

            # classical eigenvalue problem, largest eigenvalues
            E, V = lobpcg(A, k=k, n=n, largest=True)
            self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
            self.assertEqual(E, e_largest)

            # generalized eigenvalue problem, smallest eigenvalues
            E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
            self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec, rtol=0)

            # generalized eigenvalue problem, largest eigenvalues
            E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
            self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
                             atol=prec, rtol=0)

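    # Illustrative sketch (not part of the original suite; assumes torch.lobpcg with
    # default settings converges on a small, well-conditioned SPD matrix): the checks
    # in _test_lobpcg_method above all reduce to the eigen relations
    # A @ V ~= V @ diag(E) (classical) and A @ V ~= B @ V @ diag(E) (generalized).
    def _sketch_lobpcg_eigen_relation(self):
        torch.manual_seed(0)
        n, k = 32, 3
        M = torch.randn(n, n, dtype=torch.float64)
        A = M @ M.mT + n * torch.eye(n, dtype=torch.float64)  # symmetric positive definite
        E, V = torch.lobpcg(A, k=k, largest=True)
        torch.testing.assert_close(A @ V, V @ torch.diag(E), atol=1e-3, rtol=0)
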
    @skipCPUIfNoLapack
    @onlyCPU
    @dtypes(torch.double)
    def test_lobpcg_torchscript(self, device, dtype):
        from torch.testing._internal.common_utils import random_sparse_pd_matrix
        from torch._linalg_utils import matmul as mm

        lobpcg = torch.jit.script(torch.lobpcg)

        m = 500
        k = 5
        A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
        X1 = torch.randn((m, k), dtype=dtype, device=device)
        E1, V1 = lobpcg(A1, X=X1)
        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
        self.assertLess(eq_err, 1e-6)

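    # Note on the metric used above and in test_lobpcg_scipy below (a sketch, not part
    # of the original suite): V * E relies on broadcasting to scale each eigenvector
    # column by its eigenvalue, i.e. it equals V @ diag(E), so
    # norm(A @ V - V * E) / E.max() is a scale-free residual of the eigen equation.
    def _sketch_relative_residual_metric(self):
        V = torch.randn(6, 2, dtype=torch.float64)
        E = torch.tensor([3.0, 5.0], dtype=torch.float64)
        torch.testing.assert_close(V * E, V @ torch.diag(E))
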
    @unittest.skipIf(not TEST_SCIPY or (TEST_SCIPY and scipy.__version__ < '1.4.1'), "Scipy not found or older than 1.4.1")
    @skipCPUIfNoLapack
    @skipIfTorchDynamo("fails in tracing scipy.sparse.lobpcg")
    @onlyCPU
    @dtypes(torch.double)
    def test_lobpcg_scipy(self, device, dtype):
        """Compare torch and scipy.sparse.linalg implementations of lobpcg
        """
        import time
        from torch.testing._internal.common_utils import random_sparse_pd_matrix
        from torch._linalg_utils import matmul as mm
        from scipy.sparse.linalg import lobpcg as scipy_lobpcg
        import scipy.sparse

        def toscipy(A):
            if A.layout == torch.sparse_coo:
                values = A.coalesce().values().cpu().numpy().copy()
                indices = A.coalesce().indices().cpu().numpy().copy()
                return scipy.sparse.coo_matrix((values, (indices[0], indices[1])), A.shape)
            return A.cpu().numpy().copy()

        niter = 1000
        repeat = 10
        m = 500  # size of the square matrix
        k = 7  # the number of requested eigenpairs
        A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
        B1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
        X1 = torch.randn((m, k), dtype=dtype, device=device)

        A2 = toscipy(A1)
        B2 = toscipy(B1)
        X2 = toscipy(X1)

        lambdas1 = []

        def tracker(worker):
            lambdas1.append(worker.E[:])

        tol = 1e-8
        # tol for scipy lobpcg will be chosen so that the number of
        # iterations will be equal or very close to pytorch lobpcg
        # (that is around 170-180)

        # Standard eigenvalue problem
        E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
        E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=1.1 * tol)
        iters1 = len(lambdas1)
        iters2 = len(lambdas2)
        self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))

        E2a, V2a = scipy_lobpcg(A2, X2, maxiter=niter, largest=False)

        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
        eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
        self.assertLess(eq_err, 1e-6)  # std
        self.assertLess(eq_err_scipy, 1e-6)  # std

        self.assertEqual(E1, torch.from_numpy(E2.copy()))

        # Generalized eigenvalue problem
        lambdas1 = []

        def tracker(worker):
            lambdas1.append(worker.E[:])

        E1, V1 = torch.lobpcg(A1, B=B1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
        E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=39 * tol)
        E2a, V2a = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=False)
        iters1 = len(lambdas1)
        iters2 = len(lambdas2)
        self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))

        eq_err = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
        eq_err_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
        self.assertLess(eq_err, 1e-6)  # general
        self.assertLess(eq_err_scipy, 1e-6)  # general

        self.assertEqual(E1, torch.from_numpy(E2.copy()))

        # Timings
        elapsed_ortho = 0
        elapsed_ortho_general = 0
        elapsed_scipy = 0
        elapsed_general_scipy = 0
        for i in range(repeat):
            start = time.time()
            torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol)
            end = time.time()
            elapsed_ortho += end - start

            start = time.time()
            torch.lobpcg(A1, X=X1, B=B1, niter=niter, method='ortho', tol=tol)
            end = time.time()
            elapsed_ortho_general += end - start

            start = time.time()
            scipy_lobpcg(A2, X2, maxiter=niter, tol=1.1 * tol)
            end = time.time()
            elapsed_scipy += end - start

            start = time.time()
            scipy_lobpcg(A2, X2, B=B2, maxiter=niter, tol=39 * tol)
            end = time.time()
            elapsed_general_scipy += end - start

        elapsed_ortho_ms = 1000.0 * elapsed_ortho / repeat
        elapsed_ortho_general_ms = 1000.0 * elapsed_ortho_general / repeat
        elapsed_scipy_ms = 1000.0 * elapsed_scipy / repeat
        elapsed_general_scipy_ms = 1000.0 * elapsed_general_scipy / repeat

        print(f'''
CPU timings: torch.lobpcg vs scipy.sparse.linalg.lobpcg
-------------------------------------------------------
              | standard | generalized | method
torch.lobpcg  | {elapsed_ortho_ms:10.2f} | {elapsed_ortho_general_ms:10.2f} | ortho
scipy_lobpcg  | {elapsed_scipy_ms:10.2f} | {elapsed_general_scipy_ms:10.2f} | N/A
-(input size: {m:4}, eigenpairs:{k:2}, units: ms per call)-
        ''')

        # Handling of very small tolerance
        tol = 1e-100

        lambdas1 = []

        def tracker(worker):
            lambdas1.append(worker.E[:])

        E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
        iters1 = len(lambdas1)
        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()

        try:
            E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
            iters2 = len(lambdas2)
            eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
        except Exception as msg:
            print('Calling scipy_lobpcg failed [standard]:', msg)
            iters2 = -1
            eq_err_scipy = -1

        lambdas1 = []

        def tracker(worker):
            lambdas1.append(worker.E[:])

        E1, V1 = torch.lobpcg(A1, X=X1, B=B1, niter=niter, largest=True, tracker=tracker, tol=tol)
        iters1_general = len(lambdas1)
        eq_err_general = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()

        try:
            E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
            iters2_general = len(lambdas2)
            eq_err_general_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
        except Exception as msg:
            print('Calling scipy_lobpcg failed [generalized]:', msg)
            iters2_general = -1
            eq_err_general_scipy = -1

        print(f'''\
Handling of small tol={tol:6.0e}: torch.lobpcg vs scipy.sparse.linalg.lobpcg
----------------------------------------------------------------------------
              | standard | generalized | niter | method
torch.lobpcg  | {eq_err:10.2e} | {eq_err_general:10.2e} | {iters1:6} | ortho
scipy_lobpcg  | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:6} | N/A
---(input size: {m:4}, eigenpairs:{k:2}, units: relative error, maxiter={niter:4})---
''')

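    # Illustrative sketch (not part of the original suite): the tracker callback is
    # invoked by torch.lobpcg with the solver's worker state as iterations progress,
    # so appending worker.E on every call gives the eigenvalue history whose length
    # the scipy comparison above uses as the iteration count.
    def _sketch_lobpcg_tracker(self):
        torch.manual_seed(0)
        n = 16
        M = torch.randn(n, n, dtype=torch.float64)
        A = M @ M.mT + n * torch.eye(n, dtype=torch.float64)
        history = []

        def tracker(worker):
            history.append(worker.E[:])

        torch.lobpcg(A, k=1, tracker=tracker)
        assert len(history) >= 1  # one entry per recorded iteration
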
    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None):
        dtype = t.dtype
        numpy_dtype = dtype
        if dtype in {torch.bfloat16, torch.half}:
            numpy_dtype = torch.float
        if dtype.is_complex:
            alpha = 0.9 + 0.3j if alpha is None else alpha
            beta = 0.5 + 0.6j if beta is None else beta
        else:
            alpha = 1.2 if alpha is None else alpha
            beta = 0.8 if beta is None else beta
        if activation == "gelu":
            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
        else:
            res1 = f(t, m, v, alpha=alpha, beta=beta)
        res2 = torch.full_like(res1, math.nan)
        if transpose_out:
            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
        if activation == "gelu":
            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
        else:
            f(t, m, v, alpha=alpha, beta=beta, out=res2)
        res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
        if beta != 0:
            res3 += (beta * t).to(numpy_dtype).cpu().numpy()
        if activation == "relu":
            res3 = res3 * (res3 > 0)
        elif activation == "gelu":
            res3_t = torch.from_numpy(res3).to(dtype)
            approximate = "tanh" if t.is_cuda else "none"
            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
            res3 = res3_t.to(numpy_dtype).cpu().numpy()
        else:
            assert activation is None, f"unsupported activation {activation}"
        res3 = torch.from_numpy(res3).to(dtype)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)

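    # Worked example for the reference computation in _test_addmm_addmv above (a
    # sketch, not part of the original suite): addmm/addmv compute
    # out = beta * t + alpha * (m @ v), which the helper reproduces in numpy.
    def _sketch_addmm_reference_formula(self):
        t = torch.ones(2, 2)
        m = torch.tensor([[1., 2.], [3., 4.]])
        v = torch.tensor([[1., 0.], [0., 1.]])
        alpha, beta = 2.0, 0.5
        expected = beta * t + alpha * (m @ v)
        torch.testing.assert_close(torch.addmm(t, m, v, alpha=alpha, beta=beta), expected)
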
    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4, torch.double: 1e-8,
                        torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [],
                  torch.half))
    @dtypes(torch.bfloat16, torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_addmv(self, device, dtype):
        if IS_ARM64 and device == 'cpu' and dtype == torch.float16:
            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
        # have to use torch.randn(...).to(bfloat16) instead of
        # torch.randn(..., dtype=bfloat16). randn does not support
        # bfloat16 yet.
        # "*0.2" to reduce errors for low precision
        ts = [
            0.2 * torch.randn(50, device=device).to(dtype),
            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
        ]
        vs = [
            0.2 * torch.randn(100, device=device).to(dtype),
            0.2 * torch.ones(1, device=device).to(dtype).expand(100),  # to reduce errors for low precision
        ]
        ms = [
            # 0d
            0.2 * torch.ones((), device=device).to(dtype).expand(50, 100),  # to reduce errors for low precision
            # 1d
            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
            # this initialization reduces errors for low precision for broadcasted matrices
            # by making sure that intermediate and result values are exactly representable
            # in low precision type
            0.2 * torch.randint(3, (50, 1), dtype=torch.float, device=device).to(dtype).expand(50, 100),
            # 2d
            0.2 * torch.randn((50, 100), device=device).to(dtype),
            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
        ]
        for m, v, t in itertools.product(ms, vs, ts):
            self._test_addmm_addmv(torch.addmv, t, m, v)
        # Test beta=0, t=nan
        t = torch.full((50,), math.nan, device=device).to(dtype)
        for m, v in itertools.product(ms, vs):
            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)

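    # Sketch of the beta=0 corner case exercised above (not part of the original
    # suite): with beta=0 the input tensor is ignored rather than multiplied, so a
    # NaN-filled bias must not leak into the result.
    def _sketch_addmv_beta_zero_nan(self):
        t = torch.full((2,), math.nan)
        m = torch.tensor([[1., 2.], [3., 4.]])
        v = torch.tensor([1., 1.])
        out = torch.addmv(t, m, v, beta=0)
        torch.testing.assert_close(out, m @ v)
        assert not out.isnan().any()
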
    @dtypesIfCUDA(*floating_types_and(*[torch.bfloat16] if TEST_WITH_ROCM or
                  SM53OrLater else []))
    @dtypes(torch.float, torch.double)
    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
        # tests (o, s)*(s). o is output size, s is summed size.
        o = 5
        s = 3
        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
        y_data = torch.ones(o, device=device, dtype=dtype)
        control = torch.tensor([15., 33., 51., 69., 87.], device=device, dtype=dtype)

        def _test(row_major, incx, incy, lda_tail):
            if row_major:
                a_storage = torch.full((o, s + lda_tail), float('nan'), device=device, dtype=dtype)
            else:
                a_storage = torch.full((s, o + lda_tail), float('nan'), device=device, dtype=dtype).permute(1, 0)
            a = a_storage[:o, :s].copy_(a_data)

            x_storage = torch.full((s, incx), float('nan'), device=device, dtype=dtype)
            x = x_storage[:, 0].copy_(x_data)

            y_storage = torch.full((o, incy), float('nan'), device=device, dtype=dtype)
            y = y_storage[:, 0].copy_(y_data)

            self._test_addmm_addmv(torch.addmv, y, a, x)

        for row_major, incx, incy, lda_tail in itertools.product((False, True), (1, 2), (1, 2), (0, 1)):
            _test(row_major, incx, incy, lda_tail)

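    # Sketch of the storage trick used above (not part of the original suite): the
    # operands are carved out of larger NaN-filled buffers so that the underlying BLAS
    # call sees a leading dimension (lda) and increments (incx/incy) larger than the
    # logical sizes; any read outside the logical window would drag a NaN into the
    # result. A minimal view with a non-trivial stride looks like this:
    def _sketch_strided_operand(self):
        storage = torch.full((3, 2), float('nan'))
        x = storage[:, 0].copy_(torch.tensor([1., 2., 3.]))
        self.assertEqual(x.stride(), (2,))       # increment of 2 between elements
        self.assertEqual(torch.ones(3) @ x, 6.)  # NaN padding is never touched
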
    def _test_addmm_impl(self, func, activation, device, dtype):
        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
        V = torch.randn(25, device=device).to(dtype)
        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)

        # Test 0-strided
        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)

            if t1:
                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
                self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation)

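    # Sketch of the vector-bias case mentioned in _test_addmm_impl above (not part of
    # the original suite): with a 1-D bias and beta=1, addmm broadcasts the bias
    # across rows, out = m1 @ m2 + V, which is the pattern CUDA can fuse into the
    # matmul epilogue.
    def _sketch_addmm_vector_bias(self):
        m1 = torch.randn(4, 3)
        m2 = torch.randn(3, 5)
        V = torch.randn(5)
        torch.testing.assert_close(torch.addmm(V, m1, m2), m1 @ m2 + V)
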
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    @dtypesIfMPS(torch.float32)
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
    @tf32_on_and_off(0.05)
    @bf32_on_and_off(0.05)
    def test_addmm(self, device, dtype):
        self._test_addmm_impl(torch.addmm, None, device, dtype)

    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    @dtypesIfCUDA(*floating_types_and(
                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
    @dtypes(*floating_types_and(torch.bfloat16))
    @tf32_on_and_off(0.05)
    @bf32_on_and_off(0.05)
    def test_addmm_relu(self, device, dtype):
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

    @onlyCUDA
    @skipCUDAIfNotRocm
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    @dtypesIfCUDA(*floating_types_and(
                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
    @dtypes(*floating_types_and(torch.bfloat16))
    @tf32_on_and_off(0.05)
    @bf32_on_and_off(0.05)
    def test_addmm_relu_tunableop_rocm(self, device, dtype):
        torch.cuda.tunable.enable(True)
        ordinal = torch.cuda.current_device()
        filename = f"tunableop_results{ordinal}.csv"
        torch.cuda.tunable.set_filename(filename)
        iterations = torch.cuda.tunable.get_max_tuning_iterations()
        torch.cuda.tunable.set_max_tuning_iterations(10)
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
        # clean up, remove any file that was generated
        try:
            import os
            os.remove(filename)
        except FileNotFoundError:
            pass
        # reset back to prior settings
        torch.cuda.tunable.set_max_tuning_iterations(iterations)
        torch.cuda.tunable.enable(False)

    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    @dtypesIfCUDA(*floating_types_and(
                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
    @dtypes(*floating_types_and(torch.bfloat16))
    @tf32_on_and_off(0.05)
    @bf32_on_and_off(0.05)
    def test_addmm_gelu(self, device, dtype):
        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)

    @dtypes(torch.float, torch.double)
    @dtypesIfCUDA(*floating_and_complex_types())
    @tf32_on_and_off(0.005)
    @bf32_on_and_off(0.005)
    def test_addmm_sizes(self, device, dtype):
        for m in [0, 1, 25]:
            for n in [0, 1, 10]:
                for k in [0, 1, 8]:
                    M = torch.randn(n, m, device=device).to(dtype)
                    m1 = torch.randn(n, k, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self._test_addmm_addmv(torch.addmm, M, m1, m2)

                    m1 = torch.randn(n, k + 1, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2))
                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2))

    @dtypes(torch.half)
    @onlyCUDA
    def test_addmm_baddbmm_overflow(self, device, dtype):
        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
        inp = torch.zeros(128, 128, dtype=torch.half, device=device)
        mat1 = torch.ones(128, 1000, dtype=torch.half, device=device) * 100
        mat2 = torch.ones(1000, 128, dtype=torch.half, device=device) * 100
        out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
        # just check for no overflow on ROCM
        if TEST_WITH_ROCM:
            self.assertFalse(out.isinf().any())
        else:
            self.assertTrue((out == 10000.).all())
        inp = torch.zeros(3, 128, 128, dtype=torch.half, device=device)
        mat1 = torch.ones(3, 128, 1000, dtype=torch.half, device=device) * 100
        mat2 = torch.ones(3, 1000, 128, dtype=torch.half, device=device) * 100
        out = torch.baddbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
        if TEST_WITH_ROCM:
            self.assertFalse(out.isinf().any())
        else:
            self.assertTrue((out == 10000.).all())
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig

    @dtypes(torch.float)
    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
        for shape in [[3, 2, 2], [2, 20, 20]]:
            mat1, mat2 = (torch.randn(shape, dtype=dtype, device=device) for _ in range(2))
            inputs = [torch.randn(shape, dtype=dtype, device=device),
                      torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
            outs = [None, torch.randn(shape, dtype=dtype, device=device),
                    torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
            options = itertools.product(inputs, outs)
            for input, out in options:
                y_ref = torch.bmm(mat1, mat2)
                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
                self.assertEqual(y_ref, y)

    @dtypes(torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64)
    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
        if dtype != torch.float32:
            with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
                y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
        else:
            out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
            y_ref = torch.bmm(batch1, batch2)
            y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
            self.assertEqual(out, y_ref)

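    # Sketch of why the beta=0 checks above matter (not part of the original suite):
    # a naive "beta * input + mat1 @ mat2" would turn a NaN bias into NaN even when
    # beta is zero, because 0 * nan is nan; baddbmm is therefore expected to skip the
    # input entirely when beta == 0, as the tests above assert.
    def _sketch_baddbmm_beta_zero(self):
        self.assertTrue(math.isnan(0.0 * math.nan))
        inp = torch.full((1, 2, 2), math.nan)
        b1 = torch.rand(1, 2, 2)
        b2 = torch.rand(1, 2, 2)
        torch.testing.assert_close(torch.baddbmm(inp, b1, b2, beta=0.0), torch.bmm(b1, b2))
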
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyCUDA
    def test_matmul_45724(self, device):
        # https://github.com/pytorch/pytorch/issues/45724
        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
        b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()
        torch.matmul(a, b, out=c)
        self.assertEqual(c, cpu_result)

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(SM90OrLater and not TEST_WITH_ROCM, "Expected failure on sm90")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyCUDA
    @parametrize("k", [16, 32])
    @parametrize("n", [16, 32])
    @parametrize("use_transpose_a", [True, False])
    @parametrize("use_transpose_b", [True, False])
    def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b):
        def genf_int_float(x, y, use_transpose):
            if use_transpose:
                x, y = y, x
            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
            x_float = x_int8.to(torch.float32)
            if use_transpose:
                return x_int8.t(), x_float.t()
            return x_int8, x_float

        def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
            a_int8, a_float = genf_int_float(m, k, transpose_a)
            b_int8, b_float = genf_int_float(k, n, transpose_b)
            c_int32 = torch._int_mm(a_int8, b_int8)
            self.assertTrue(c_int32.dtype is torch.int32)
            self.assertEqual(c_int32.device, torch.device(device))
            if test_equal:
                self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
            else:
                self.assertNotEqual(c_int32.float(), torch.mm(a_float, b_float))
            c_int32_result = c_int32.new_empty(c_int32.size())
            # Checking out variant
            torch._int_mm(a_int8, b_int8, out=c_int32_result)
            if test_equal:
                self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
            else:
                self.assertNotEqual(c_int32_result.float(), torch.mm(a_float, b_float))

        # NOTE: We're just exercising terrible failures here.
        version = _get_torch_cuda_version()
        SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
        SM70 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 0)
        SM75 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 5)

        if TEST_WITH_ROCM:
            _test(17, k, n, use_transpose_a, use_transpose_b, True)
        elif version >= (11, 7):
            if not use_transpose_a and use_transpose_b:
                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
                    _test(17, k, n, use_transpose_a, use_transpose_b, version > (11, 7))
                else:
                    with self.assertRaisesRegex(RuntimeError,
                                                "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
                        _test(17, k, n, use_transpose_a, use_transpose_b)

            if use_transpose_a and not use_transpose_b:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
                    _test(17, k, n, use_transpose_a, use_transpose_b)

            if use_transpose_a and use_transpose_b:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
                    _test(17, k, n, use_transpose_a, use_transpose_b)

            if not use_transpose_a and not use_transpose_b:
                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
                    _test(17, k, n, use_transpose_a, use_transpose_b)
                else:
                    with self.assertRaisesRegex(RuntimeError,
                                                "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
                        _test(17, k, n, use_transpose_a, use_transpose_b)
        else:
            with self.assertRaisesRegex(RuntimeError, "_int_mm_out_cuda not compiled for CUDA"):
                _test(17, k, n, use_transpose_a, use_transpose_b, False)

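    # Minimal usage sketch for torch._int_mm (not part of the original suite; assumes
    # the CPU path exercised by test__int_mm_cpu below is available, and note the
    # CUDA shape/arch restrictions exercised above and in test__int_mm_errors): the op
    # multiplies two int8 matrices and accumulates into int32, matching an ordinary
    # float matmul of the same values.
    def _sketch_int_mm_semantics(self):
        a = torch.randint(-10, 10, (17, 32), dtype=torch.int8)
        b = torch.randint(-10, 10, (32, 32), dtype=torch.int8)
        c = torch._int_mm(a, b)  # int8 x int8 -> int32, exact for these magnitudes
        self.assertTrue(c.dtype is torch.int32)
        self.assertEqual(c.float(), torch.mm(a.float(), b.float()))
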
Worker 6317*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") 6318*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") 6319*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 6320*da0073e9SAndroid Build Coastguard Worker def test__int_mm_errors(self, device): 6321*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM: 6322*da0073e9SAndroid Build Coastguard Worker self.skipTest("_int_mm not compiled for ROCM") 6323*da0073e9SAndroid Build Coastguard Worker 6324*da0073e9SAndroid Build Coastguard Worker version = _get_torch_cuda_version() 6325*da0073e9SAndroid Build Coastguard Worker if version < (11, 7): 6326*da0073e9SAndroid Build Coastguard Worker self.skipTest("_int_mm only compiled for CUDA 11.7") 6327*da0073e9SAndroid Build Coastguard Worker 6328*da0073e9SAndroid Build Coastguard Worker def genf_int(x, y): 6329*da0073e9SAndroid Build Coastguard Worker return torch.empty((x, y), dtype=torch.int8, device=device) 6330*da0073e9SAndroid Build Coastguard Worker 6331*da0073e9SAndroid Build Coastguard Worker def _gen_pair(m, k, n): 6332*da0073e9SAndroid Build Coastguard Worker return genf_int(m, k), genf_int(k, n) 6333*da0073e9SAndroid Build Coastguard Worker 6334*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 6335*da0073e9SAndroid Build Coastguard Worker r"self.size\(0\) needs to be greater than 16, but got 16", 6336*da0073e9SAndroid Build Coastguard Worker lambda: torch._int_mm(*_gen_pair(16, 8, 32))) 6337*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 6338*da0073e9SAndroid Build Coastguard Worker r"self.size\(1\) needs to be greater than 0 and a multiple of 8, but got 7", 6339*da0073e9SAndroid Build Coastguard Worker lambda: torch._int_mm(*_gen_pair(17, 7, 32))) 6340*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 6341*da0073e9SAndroid Build Coastguard Worker r"self.size\(1\) needs to match mat2.size\(0\) but got 8 and 7", 6342*da0073e9SAndroid Build Coastguard Worker lambda: torch._int_mm(genf_int(17, 8), genf_int(7, 32))) 6343*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 6344*da0073e9SAndroid Build Coastguard Worker r"mat2.size\(1\) needs to be greater than 0 and a multiple of 8, but got 31", 6345*da0073e9SAndroid Build Coastguard Worker lambda: torch._int_mm(*_gen_pair(17, 8, 31))) 6346*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 6347*da0073e9SAndroid Build Coastguard Worker r"expected scalar type Char but found Float", 6348*da0073e9SAndroid Build Coastguard Worker lambda: torch._int_mm(genf_int(17, 8).float(), genf_int(8, 32))) 6349*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 6350*da0073e9SAndroid Build Coastguard Worker r"expected scalar type Char but found Float", 6351*da0073e9SAndroid Build Coastguard Worker lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32).float())) 6352*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 6353*da0073e9SAndroid Build Coastguard Worker r"Expected result dtype to be of type kInt but got float", 6354*da0073e9SAndroid Build Coastguard Worker lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 32).float())) 6355*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 6356*da0073e9SAndroid Build Coastguard Worker r"Expected result.size\(0\) to be 17 but got 15", 6357*da0073e9SAndroid Build 
        self.assertRaisesRegex(RuntimeError,
                               r"Expected result.size\(0\) to be 17 but got 15",
                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(15, 32).int()))
        self.assertRaisesRegex(RuntimeError,
                               r"Expected result.size\(0\) to be 17 but got 16",
                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))

    @onlyCPU
    @parametrize("m", [0, 8, 17])
    @parametrize("k", [0, 16, 32])
    @parametrize("n", [16, 32])
    @parametrize("use_transpose_a", [True, False])
    @parametrize("use_transpose_b", [True, False])
    @parametrize("non_contig_type", [0, 1, 2])
    def test__int_mm_cpu(self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):
        # non_contig_type:
        # 0: the whole data buffer is contiguous (can be transposed)
        # 1: stride of one dimension is 1, but the whole buffer is not contiguous
        # 2: Neither stride is 1

        def genf_int_float(x, y, use_transpose, non_contig_type):
            if use_transpose:
                x, y = y, x
            if non_contig_type != 0:
                y = y * 2
            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
            x_float = x_int8.to(torch.float32)
            if non_contig_type == 1:
                x_int8 = x_int8[:, : y // 2]
                x_float = x_float[:, : y // 2]
            elif non_contig_type == 2:
                x_int8 = x_int8[:, ::2]
                x_float = x_float[:, ::2]
            if use_transpose:
                return x_int8.t(), x_float.t()
            return x_int8, x_float

        if non_contig_type != 0 and (m == 0 or k == 0):
            return
        a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
        b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
        c_int32 = torch._int_mm(a_int8, b_int8)
        self.assertTrue(c_int32.dtype is torch.int32)
        self.assertEqual(c_int32.device, torch.device(device))
        self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
        c_int32_result = c_int32.new_empty(c_int32.size())
        # Checking out variant
        torch._int_mm(a_int8, b_int8, out=c_int32_result)
        self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyNativeDeviceTypes
    def test__convert_weight_to_int4pack(self, device):
        # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
        test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)]
        if self.device_type == 'cuda' and not SM80OrLater:
            self.skipTest("requires SM80 or later")

        if TEST_WITH_ROCM:
            if not CDNA2OrLater():
                self.skipTest("_int4_mm is supported only for CDNA2 or later")

        torch.manual_seed(1)
        for shape, innerKTiles in test_list:
            b = torch.rand(shape, dtype=torch.bfloat16, device=device)
            b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles)
            b_int4pack_meta = torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles)
            self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyNativeDeviceTypes
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test__int4_mm(self, device, m, k, n):
        if self.device_type == 'cuda' and not SM80OrLater:
            self.skipTest("requires SM80 or later")

        if TEST_WITH_ROCM:
            if not CDNA2OrLater():
                self.skipTest("_int4_mm is supported only for CDNA2 or later")
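
        # Sketch of the int4 weight-only flow exercised below (semantics inferred
        # from the helpers used in this test, not a definitive description):
        #   1. _group_quantize_tensor packs the bf16 weight into uint8 nibbles plus
        #      per-group scales and zero points (q_group columns per group),
        #   2. torch._convert_weight_to_int4pack repacks that buffer into the layout
        #      the kernel expects (controlled by inner_k_tiles),
        #   3. torch._weight_int4pack_mm(a, b_int4pack, q_group, scales_and_zeros)
        #      approximates torch.mm(a, b) to within the 5% mean relative error
        #      asserted at the end.
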
        q_group = 32
        inner_k_tiles = 2

        torch.manual_seed(1)
        a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)

        def convert_weight_to_int4pack(b):
            b_uint8, b_scales_and_zeros = _group_quantize_tensor(
                b, n_bit=4, q_group_size=q_group
            )
            b_int4pack = torch._convert_weight_to_int4pack(
                b_uint8, inner_k_tiles
            )

            return b_int4pack, b_scales_and_zeros

        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
            return torch._weight_int4pack_mm(
                a, b_int4pack, q_group, b_scales_and_zeros
            )

        b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)

        for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []):
            a = a_bf16.to(dtype=dtype)
            b = b_bf16.to(dtype=dtype)
            b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
            ref = torch.mm(a, b)
            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertTrue(mean_err < 0.05)

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyNativeDeviceTypes
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test_compile_int4_mm(self, device, m, k, n):
        if self.device_type == 'cuda' and not SM80OrLater:
            self.skipTest("requires SM80 or later")

        if TEST_WITH_ROCM:
            if not CDNA2OrLater():
                self.skipTest("_int4_mm is supported only for CDNA2 or later")

        q_group = 32
        inner_k_tiles = 2

        torch.manual_seed(1)
        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b = torch.rand((k, n), dtype=torch.bfloat16, device=device)

        b_int32, b_scales_and_zeros = _group_quantize_tensor(
            b, n_bit=4, q_group_size=q_group
        )

        @torch.compile
        def int4_mm(a, b_int32, b_scales_and_zeros):
            b_int4pack = torch._convert_weight_to_int4pack(
                b_int32, inner_k_tiles
            )
            return torch._weight_int4pack_mm(
                a, b_int4pack, q_group, b_scales_and_zeros
            )

        res = int4_mm(a, b_int32, b_scales_and_zeros)
        ref = torch.mm(a, b)

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)

    @onlyCPU
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test__int8_mm(self, device, m, k, n):
        torch.manual_seed(1)
        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)

        def convert_weight_to_int8pack(b):
            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
                b, -128, 127, torch.int8
            )
            return b_int8pack, b_scales

        def weight_int8pack_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(
                a, b_int8pack, b_scales
            )

        b_int8pack, b_scales = convert_weight_to_int8pack(b)
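        # Sketch of the expected relationship (an approximation, not an identity):
        # b was quantized per output channel, so b_int8pack[i, :] * b_scales[i]
        # roughly reconstructs row i of b and
        #     torch._weight_int8pack_mm(a, b_int8pack, b_scales)  ~=  a @ b.T
        # which is what the 5% mean-relative-error check below verifies.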
        res = weight_int8pack_mm(a, b_int8pack, b_scales)
        ref = torch.mm(a, b.transpose(0, 1))

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)

    @onlyCPU
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test_compile_int8_mm(self, device, m, k, n):
        torch.manual_seed(1)
        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)

        b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
            b, -128, 127, torch.int8
        )

        @torch.compile
        def int8_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(
                a, b_int8pack, b_scales
            )

        res = int8_mm(a, b_int8pack, b_scales)
        ref = torch.mm(a, b.transpose(0, 1))

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)

    @onlyCPU
    @parametrize("m", [32, 35, 36, 40, 64])
    @parametrize("k", [32, 35, 36, 40, 64])
    # NOTE: This is intended to cover fp16_gemv_trans in
    # BlasKernel.cpp. Currently, bounds being divisible by 32, 8-but-not-32, and 4-but-not-8
    # all matter.
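    # A rough sketch of what the flag toggled below controls (assumed semantics,
    # based on its name): with reduced-precision reduction enabled, the fp16 GEMV
    # may accumulate partial sums in fp16 rather than fp32, so the result is only
    # compared against the fp32-accumulated reference to atol/rtol of 1e-2.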
    def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k):
        torch.manual_seed(1)
        a = torch.rand((m, k), dtype=torch.half, device=device)
        b = torch.rand((1, k), dtype=torch.half, device=device)

        prev = torch._C._get_cpu_allow_fp16_reduced_precision_reduction()
        try:
            torch._C._set_cpu_allow_fp16_reduced_precision_reduction(False)
            ref = torch.mm(a, b.t())
            try:
                torch._C._set_cpu_allow_fp16_reduced_precision_reduction(True)
            except RuntimeError as e:
                raise unittest.SkipTest from e
            res = torch.mm(a, b.t())
            torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)
        finally:
            torch._C._set_cpu_allow_fp16_reduced_precision_reduction(prev)

    @slowTest
    @onlyNativeDeviceTypes
    # bfloat16 doesn't have sufficient precision to pass this test
    @dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
    @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble)
    @tf32_on_and_off(0.01)
    @bf32_on_and_off(0.01)
    def test_mm(self, device, dtype):
        def _test_mm(n, m, p, dtype, genf):
            # reference implementation: naive triple loop, accumulating in float for half inputs
            def matrixmultiply(mat1, mat2):
                n = mat1.size(0)
                m = mat1.size(1)
                p = mat2.size(1)
                dtype_ = torch.float if dtype == torch.half else dtype
                if dtype == torch.half:
                    mat1 = mat1.float()
                    mat2 = mat2.float()
                res = torch.zeros(n, p, dtype=dtype_, device=device)
                for i, j in iter_indices(res):
                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
                return res.half() if dtype == torch.half else res

            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 1
            mat1 = genf(n, m)
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 2
            mat1 = genf(m, n).t()
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # test with zero stride
            mat1 = genf(n, m)
            mat2 = genf(m, 1).expand(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the _out variant in torch.mm().
            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the _out variant in torch.mm().
            # non contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

        def genf_int(x, y):
            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)

        def genf_bfloat(x, y):
            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1

        def genf_float(x, y):
            return torch.randn(x, y, dtype=dtype, device=device)

        def genf_Half(x, y):
            return torch.randn(x, y, dtype=dtype, device=device)

        for (n, m, p) in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
            if (dtype == torch.int32) or (dtype == torch.int64):
                genf = genf_int
            elif (dtype == torch.bfloat16):
                genf = genf_bfloat
            elif (dtype == torch.half):
                genf = genf_Half
            else:
                genf = genf_float

            _test_mm(n, m, p, dtype, genf)

    @onlyNativeDeviceTypes
    def test_mm_bmm_non_memory_dense(self, device):
        def _slice(tensor, fn):
            return fn(tensor)[..., ::2]
        A = torch.randn(3, 6, dtype=torch.cfloat, device=device)
        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
        out = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
        out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
        A_conj = _slice(A, torch.conj)
        A_conj_physical = _slice(A, torch.conj_physical)

        self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out))
        self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out))
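
        # The sliced operands above are not memory-dense (every other column is
        # selected). torch.conj only flips the conjugate bit on a view, whereas
        # torch.conj_physical materializes the conjugated values, so each pair of
        # calls must produce identical results; the batched bmm variants below
        # repeat the same check.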

        Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
        Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
        Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
        out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).mT

        Ab_conj = _slice(Ab, torch.conj)
        Ab_conj_physical = _slice(Ab, torch.conj_physical)

        def t_b(tensor):
            return tensor.mT

        self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b))
        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b))

        # test broadcasting
        self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b))
        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b))

    @onlyNativeDeviceTypes
    def test_mm_conjtranspose(self, device):
        A = torch.randn(3, 3, dtype=torch.cfloat, device=device)
        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)

        # A conjtranspose
        out1 = torch.mm(A.t().conj(), B)
        out1_ref = torch.mm(A.t().conj_physical(), B)
        self.assertEqual(out1, out1_ref)

        # B conjtranspose
        out1 = torch.mm(A, B.t().conj())
        out1_ref = torch.mm(A, B.t().conj_physical())
        self.assertEqual(out1, out1_ref)

        # A&B conjtranspose
        out1 = torch.mm(A.t().conj(), B.t().conj())
        out1_ref = torch.mm(A.t().conj_physical(), B.t().conj_physical())
        self.assertEqual(out1, out1_ref)

    @onlyNativeDeviceTypes
    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
        b = torch.randn(10, 20, dtype=torch.float32, device=device)
        with self.assertRaisesRegex(RuntimeError, "expected .* and .* to have the same dtype, but got:"):
            torch.mm(a, b)

    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    def test_strided_mm_bmm(self, device, dtype):
        # Tests strided view case with stride smaller than corresponding dimension size
        x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device)
        new_shape = [2, 2, 2]
        new_stride = [3, 1, 1]
        sx = torch.as_strided(x, size=new_shape, stride=new_stride)

        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx)

        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx[0])

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @onlyNativeDeviceTypes
    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
    @tf32_on_and_off(0.05)
    @bf32_on_and_off(0.05)
    def test_bmm(self, device, dtype):
        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
            # cuBLAS does not guarantee BFloat16 support on SM < 53.
            # So on PyTorch, we consider BFloat16 support on SM < 53 as
            # undefined behavior
            return

        batch_sizes = [1, 10]
        M, N, O = 23, 15, 12
        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32

        is_supported = True
        if dtype == torch.bfloat16 and self.device_type == 'cuda':
            is_supported = TEST_WITH_ROCM or SM53OrLater

        if not is_supported:
            for num_batches in batch_sizes:
                b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
                b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
                self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
                                       lambda: torch.bmm(b1, b2))
            return

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_inputs(num_batches):
            # transposed tensors
            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1)
                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1)
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                yield b1, b2
            # broadcasting tensors
            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, M, N)
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, N, O)
                yield b1, b2
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = torch.randn(shape1, dtype=dtype, device=device)
                b2 = torch.randn(shape2, dtype=dtype, device=device)
                yield b1, b2

        for num_batches in batch_sizes:
            for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))):
                res1 = torch.bmm(b1, b2)
                res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
                    .permute(perm3).contiguous().permute(invert_perm(perm3))
                torch.bmm(b1, b2, out=res2)
                expect = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
                self.assertEqual(expect, res1)
                self.assertEqual(expect, res2)

                if self.device_type == 'cuda':
                    # check that mixed arguments are rejected
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()))

    def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
        getattr(out_tensor, func + "_")(b1, b2)
        self.assertEqual(out_tensor, ref)
        res3 = out_tensor.clone()

        with self.assertWarnsOnceRegex(
                UserWarning, f"This overload of {func}_ is deprecated"):
            getattr(out_tensor, func + "_")(1, b1, b2)
        self.assertEqual(out_tensor, ref * 2)
        getattr(res3, func + "_")(b1, b2, beta=1)
        self.assertEqual(out_tensor, res3)

        with self.assertWarnsOnceRegex(
                UserWarning, f"This overload of {func}_ is deprecated"):
            getattr(out_tensor, func + "_")(1., .5, b1, b2)
        self.assertEqual(out_tensor, ref * 2.5)
        getattr(res3, func + "_")(b1, b2, beta=1., alpha=.5)
        self.assertEqual(out_tensor, res3)

        with self.assertWarnsOnceRegex(
                UserWarning, f"This overload of {func} is deprecated"):
            self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))

        res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5)
        self.assertEqual(res4, ref * 3)

        nan = torch.full_like(out_tensor, math.nan)
        res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
        self.assertEqual(res5, ref)

        if b1.is_complex():
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1j, alpha=.5j)
            self.assertEqual(res6, out_tensor * .1j + .5j * ref)
        else:
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1, alpha=.5)
            self.assertEqual(res6, out_tensor * .1 + .5 * ref)

        res7 = torch.full_like(out_tensor, math.nan)
        getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
        self.assertEqual(res7, ref)

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @onlyNativeDeviceTypes
    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
    @tf32_on_and_off(0.05)
    @bf32_on_and_off(0.05)
    def test_addbmm(self, device, dtype):
        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
            # cuBLAS does not guarantee BFloat16 support on SM < 53.
            # So on PyTorch, we consider BFloat16 support on SM < 53 as
            # undefined behavior
            return

        num_batches = 2
        M, N, O = 16, 17, 18

        is_supported = True
        if dtype == torch.bfloat16:
            if self.device_type == 'cpu':
                self.precision = 1  # 43 vs 43.75
            else:
                is_supported = TEST_WITH_ROCM or SM53OrLater

        if not is_supported:
            b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
            b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
            t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
                                   lambda: torch.addbmm(t, b1, b2))
            return

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
            numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
            # transposed tensors
            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
                for perm3 in itertools.permutations((0, 1)):
                    b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) * 0.1
                    b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) * 0.1
                    b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                    b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                    ref = torch.from_numpy(
                        b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                    ).to(device=device, dtype=dtype).sum(0)
                    out_tensor = torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
                    yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N) * 0.1
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O) * 0.1
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype).sum(0)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) * 0.1
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) * 0.1
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype).sum(0)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)

    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
    @onlyNativeDeviceTypes
    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
    @tf32_on_and_off(0.05)
    @bf32_on_and_off(0.05)
    def test_baddbmm(self, device, dtype):
        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
            # cuBLAS does not guarantee BFloat16 support on SM < 53.
            # So on PyTorch, we consider BFloat16 support on SM < 53 as
            # undefined behavior
            return

        num_batches = 10
        M, N, O = 12, 8, 50

        is_supported = True
        if dtype == torch.bfloat16 and self.device_type == 'cuda':
            is_supported = TEST_WITH_ROCM or SM53OrLater

        if not is_supported:
            b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
            b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
            t = make_tensor((num_batches, M, O), dtype=dtype, device=device, low=-1, high=1)
            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
                                   lambda: torch.baddbmm(t, b1, b2))
            return

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
            numpy_dtype = dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
            # transposed tensors
            for perm1, perm2, perm3 in itertools.product(itertools.permutations((0, 1, 2)), repeat=3):
                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                out_tensor = out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
                yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N)
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)

    @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_pinverse(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_arg = partial(make_fullrank, device=device, dtype=dtype)

        def run_test(M):
            # Testing against the four Moore-Penrose conditions that define the pseudo-inverse
            MPI = torch.pinverse(M)
            MPI_ = MPI.cpu().numpy()
            M_ = M.cpu().numpy()
            if M.numel() > 0:
                self.assertEqual(M_, np.matmul(np.matmul(M_, MPI_), M_))
                self.assertEqual(MPI_, np.matmul(np.matmul(MPI_, M_), MPI_))
                self.assertEqual(np.matmul(M_, MPI_), np.matmul(M_, MPI_).swapaxes(-2, -1).conj())
                self.assertEqual(np.matmul(MPI_, M_), np.matmul(MPI_, M_).swapaxes(-2, -1).conj())
            else:
                self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2]))
        for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5),  # square matrices
                      (3, 2), (5, 3, 2), (7, 5, 3, 2),  # tall (thin) matrices
                      (2, 3), (5, 2, 3), (7, 5, 2, 3),  # wide (fat) matrices
                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
            M = torch.randn(*sizes, dtype=dtype, device=device)
            run_test(M)

        # Test inverse and pseudo-inverse for invertible matrix
        for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]:
            matsize = sizes[-1]
            batchdims = sizes[:-2]
            M = make_arg(*batchdims, matsize, matsize)
            self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
                             atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(torch.double, torch.cdouble)
    def test_matrix_power_non_negative(self, device, dtype):
        def check(*size):
            t = make_tensor(size, dtype=dtype, device=device)
            for n in range(8):
                res = torch.linalg.matrix_power(t, n)
                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
                self.assertEqual(res.cpu(), torch.from_numpy(ref))

        check(0, 0)
        check(1, 1)
        check(5, 5)
        check(0, 3, 3)
        check(2, 3, 3)

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(torch.double, torch.cdouble)
    def test_matrix_power_negative(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_arg = partial(make_fullrank, device=device, dtype=dtype)

        def check(*size):
            t = make_arg(*size)

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(torch.double, torch.cdouble)
    def test_matrix_power_non_negative(self, device, dtype):
        def check(*size):
            t = make_tensor(size, dtype=dtype, device=device)
            for n in range(8):
                res = torch.linalg.matrix_power(t, n)
                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
                self.assertEqual(res.cpu(), torch.from_numpy(ref))

        check(0, 0)
        check(1, 1)
        check(5, 5)
        check(0, 3, 3)
        check(2, 3, 3)

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(torch.double, torch.cdouble)
    def test_matrix_power_negative(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_arg = partial(make_fullrank, device=device, dtype=dtype)

        def check(*size):
            t = make_arg(*size)
            for n in range(-7, 0):
                res = torch.linalg.matrix_power(t, n)
                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
                self.assertEqual(res.cpu(), torch.from_numpy(ref))

        check(0, 0)
        check(5, 5)
        check(2, 0, 0)
        check(0, 3, 3)
        check(2, 3, 3)
        check(2, 3, 5, 5)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.complex64)
    def test_linalg_matrix_exp_utils(self, device, dtype):
        # test linear combination
        def run_test(coeff_shape, data_shape):
            coeffs = torch.rand(*coeff_shape, device=device, dtype=torch.float)
            x = torch.rand(coeff_shape[1], *data_shape, device=device, dtype=dtype)

            res1 = torch._compute_linear_combination(x, coeffs)
            res2 = (x.unsqueeze(0) * coeffs.view(*coeff_shape, *([1] * len(data_shape)))).sum(1)
            self.assertEqual(res1, res2, atol=1e-5, rtol=0.0)

            # check `out=` version
            res3 = torch.zeros(coeff_shape[0], *data_shape, device=device, dtype=dtype)
            torch._compute_linear_combination(x, coeffs, out=res3)
            self.assertEqual(res1, res3, atol=1e-5, rtol=0.0)

            res4 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
            torch._compute_linear_combination(x, coeffs, out=res4)
            self.assertEqual(res1, res4 - 1.0, atol=1e-5, rtol=0.0)

            res5 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
            res5_clone = res5.clone()
            torch._compute_linear_combination(x, coeffs, out=res5)
            self.assertEqual(res1, res5 - res5_clone, atol=1e-5, rtol=0.0)

        run_test([1, 3], [2, 2])
        run_test([3, 1], [2, 2])
        run_test([1, 10], [10, 10])
        run_test([10, 1], [10, 10])
        run_test([5, 3], [2, 2])
        run_test([5, 3], [100, 100])
        run_test([3, 4], [3, 3, 3])
        run_test([3, 4], [3, 3, 3, 3])

        # Regression test for https://github.com/pytorch/pytorch/issues/94124
        with self.assertRaises(RuntimeError):
            x = torch.rand([], device=device, dtype=dtype)
            coeffs = torch.rand([2, 2], device=device, dtype=dtype)
            res = torch._compute_linear_combination(x, coeffs)

    @onlyCPU
    @skipCPUIfNoLapack
    @dtypes(torch.complex64)
    def test_linalg_matrix_exp_no_warnings(self, device, dtype):
        # this tests https://github.com/pytorch/pytorch/issues/80948
        with freeze_rng_state():
            torch.manual_seed(42)
            tens = 0.5 * torch.randn(10, 3, 3, dtype=dtype, device=device)
            tens = (0.5 * (tens.transpose(-1, -2) + tens))
            with warnings.catch_warnings(record=True) as w:
                tens.imag = torch.matrix_exp(tens.imag)
                self.assertFalse(len(w))

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
    def test_linalg_matrix_exp_boundary_cases(self, device, dtype):
        expm = torch.linalg.matrix_exp

        with self.assertRaisesRegex(RuntimeError, "Expected a floating point or complex tensor"):
            expm(torch.randn(3, 3).type(torch.int))

        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
            expm(torch.randn(3))

        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            expm(torch.randn(3, 2, 1))

        # check 1x1 matrices
        x = torch.randn(3, 3, 1, 1)
        self.assertEqual(expm(x), x.exp())
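
    # Aside (illustrative only, not exercised by the tests above or below): the 1x1
    # check in test_linalg_matrix_exp_boundary_cases generalizes to any diagonal
    # matrix, since exp(diag(d)) == diag(exp(d)). A minimal sketch; the helper name
    # and sizes are arbitrary assumptions.
    def _matrix_exp_of_diagonal_matches_elementwise_exp(self, device='cpu', dtype=torch.double):
        d = torch.randn(4, dtype=dtype, device=device)
        expected = torch.diag_embed(d.exp())
        actual = torch.linalg.matrix_exp(torch.diag_embed(d))
        return torch.allclose(actual, expected)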

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
    def test_linalg_matrix_exp_perverse_nan_values(self, device, dtype):
        expm = torch.linalg.matrix_exp

        def with_nan(x):
            x[0, 0, 0] = torch.nan
            return x

        # Check small batches
        x = with_nan(torch.randn(1, 1, 1))
        self.assertTrue(torch.isnan(expm(x)).any())
        x = with_nan(torch.randn(1, 2, 2))
        for v in [1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 1000]:
            self.assertTrue(torch.isnan(expm(x / v)).any())

        # Check large batches
        x = with_nan(torch.randn(2, 2, 2))
        self.assertTrue(torch.isnan(expm(x)).any())
        x = with_nan(torch.randn(4096, 2, 2))
        self.assertTrue(torch.isnan(expm(x)).any())

    @slowTest
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_matrix_exp_analytic(self, device, dtype):
        expm = torch.linalg.matrix_exp
        # check zero matrix
        x = torch.zeros(20, 20, dtype=dtype, device=device)
        self.assertTrue((expm(x) == torch.eye(20, 20, dtype=dtype, device=device)).all().item())

        def normalize_to_1_operator_norm(sample, desired_norm):
            sample_norm, _ = sample.abs().sum(-2).max(-1)
            sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
            return sample_to_1_norm * desired_norm

        def gen_good_cond_number_matrices(*n):
            """
            Generates a diagonally-dominant matrix
            with the eigenvalues centered at 1
            and the radii at most (n[-1] - 1) / (n[-2] ** 2)
            """
            identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
            x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
            x = (x - x * identity) + identity
            return x

        def run_test(*n):
            if dtype == torch.float:
                thetas = [
                    1.192092800768788e-07,  # deg 1
                    5.978858893805233e-04,  # deg 2
                    5.116619363445086e-02,  # deg 4
                    5.800524627688768e-01,  # deg 8
                    1.461661507209034e+00,  # deg 12
                    3.010066362817634e+00   # deg 18
                ]
            else:  # if torch.double
                thetas = [
                    2.220446049250313e-16,  # deg 1
                    2.580956802971767e-08,  # deg 2
                    3.397168839976962e-04,  # deg 4
                    4.991228871115323e-02,  # deg 8
                    2.996158913811580e-01,  # deg 12
                    1.090863719290036e+00   # deg 18
                ]

            # generate input
            q = gen_good_cond_number_matrices(*n)
            q_ = q.cpu().numpy()
            qinv = torch.inverse(q)
            qinv_ = qinv.cpu().numpy()
            d = torch.randn(n[:-1], dtype=dtype, device=device)
            x = torch.from_numpy(
                np.matmul(q_, np.matmul(torch.diag_embed(d).cpu().numpy(), qinv_))).to(device)
            x_norm, _ = x.abs().sum(-2).max(-1)

            # test simple analytic whatever norm generated
            mexp = expm(x)
            mexp_analytic = np.matmul(
                q_,
                np.matmul(
                    torch.diag_embed(d.exp()).cpu().numpy(),
                    qinv_
                )
            )
            self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)

            # generate norms to test different degree expansions
            sample_norms = []
            for i in range(len(thetas) - 1):
                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]

            # matrices to equal norm
            for sample_norm in sample_norms:
                x_normalized = normalize_to_1_operator_norm(x, sample_norm)

                mexp = expm(x_normalized)
                mexp_analytic = np.matmul(
                    q_,
                    np.matmul(
                        torch.diag_embed((d / x_norm.unsqueeze(-1) * sample_norm).exp()).cpu().numpy(),
                        qinv_
                    )
                )
                self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)

        # single matrix
        run_test(2, 2)
        run_test(3, 3)
        run_test(4, 4)
        run_test(5, 5)
        run_test(100, 100)
        run_test(200, 200)

        # small batch of matrices
        run_test(3, 2, 2)
        run_test(3, 3, 3)
        run_test(3, 4, 4)
        run_test(3, 5, 5)
        run_test(3, 100, 100)
        run_test(3, 200, 200)

        # large batch of matrices
        run_test(3, 3, 2, 2)
        run_test(3, 3, 3, 3)
        run_test(3, 3, 4, 4)
        run_test(3, 3, 5, 5)
        run_test(3, 3, 100, 100)
        run_test(3, 3, 200, 200)
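
    # Why gen_good_cond_number_matrices above is well conditioned (a sketch, not part
    # of the test): every diagonal entry is 1 and each off-diagonal entry lies in
    # [0, 1/n**2), so by the Gershgorin circle theorem every eigenvalue sits in a disc
    # centered at 1 with radius below (n - 1) / n**2 < 1, keeping the matrix far from
    # singular. A small numeric illustration (helper name and defaults are arbitrary):
    def _gershgorin_radii_of_good_cond_matrix(self, n=5, dtype=torch.double, device='cpu'):
        identity = torch.eye(n, dtype=dtype, device=device)
        x = torch.rand(n, n, dtype=dtype, device=device) / (n ** 2)
        x = (x - x * identity) + identity
        # Radius of each Gershgorin disc: sum of absolute off-diagonal entries per row.
        radii = (x - identity).abs().sum(-1)
        return radii.max()  # always below (n - 1) / n**2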

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    def test_linalg_matrix_exp_batch(self, device, dtype):

        def run_test(*n):
            tensors_batch = torch.zeros(n, dtype=dtype, device=device)
            tensors_batch = tensors_batch.view(-1, n[-2], n[-1])

            num_matrices = tensors_batch.size(0)
            tensors_list = []
            for i in range(num_matrices):
                tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device))

            for i in range(num_matrices):
                tensors_batch[i, ...] = tensors_list[i]

            tensors_exp_map = (torch.linalg.matrix_exp(x) for x in tensors_list)
            tensors_exp_batch = torch.linalg.matrix_exp(tensors_batch)

            for i, tensor_exp in enumerate(tensors_exp_map):
                self.assertEqual(tensors_exp_batch[i, ...], tensor_exp)

        # small batch of matrices
        run_test(3, 2, 2)
        run_test(3, 3, 3)
        run_test(3, 4, 4)
        run_test(3, 5, 5)

        # large batch of matrices
        run_test(3, 3, 2, 2)
        run_test(3, 3, 3, 3)
        run_test(3, 3, 4, 4)
        run_test(3, 3, 5, 5)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_linalg_matrix_exp_compare_with_taylor(self, device, dtype):

        def normalize_to_1_operator_norm(sample, desired_norm):
            sample_norm, _ = sample.abs().sum(-2).max(-1)
            sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
            return sample_to_1_norm * desired_norm

        def gen_good_cond_number_matrices(*n):
            """
            Generates a diagonally-dominant matrix
            with the eigenvalues centered at 1
            and the radii at most (n[-1] - 1) / (n[-2] ** 2)
            """
            identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
            x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
            x = (x - x * identity) + identity
            return x

        def get_taylor_approximation(a, deg):
            a_ = a.cpu().numpy()
            identity = torch.eye(a.size(-2), a.size(-1), dtype=dtype, device=device).expand_as(a)
            res = identity.cpu().numpy()
            taylor_term = identity.cpu().numpy()

            for i in range(1, deg + 1):
                taylor_term = np.matmul(a_, taylor_term) / i
                res = res + taylor_term

            return res

        def scale_square(a, deg):
            if a.abs().pow(2).sum().sqrt() < 1.0:
                return get_taylor_approximation(a, 12)
            else:
                s = int(torch.log2(a.abs().pow(2).sum().sqrt()).ceil().item())
                b = a / (2 ** s)
                b = get_taylor_approximation(b, 18)
                for _ in range(s):
                    b = np.matmul(b, b)
                return torch.from_numpy(b).to(a.device)

        def run_test(*n):
            degs = [1, 2, 4, 8, 12, 18]
            if dtype == torch.float:
                thetas = [
                    1.192092800768788e-07,  # deg 1
                    5.978858893805233e-04,  # deg 2
                    5.116619363445086e-02,  # deg 4
                    5.800524627688768e-01,  # deg 8
                    1.461661507209034e+00,  # deg 12
                    3.010066362817634e+00   # deg 18
                ]
            else:  # if torch.double
                thetas = [
                    2.220446049250313e-16,  # deg 1
                    2.580956802971767e-08,  # deg 2
                    3.397168839976962e-04,  # deg 4
                    4.991228871115323e-02,  # deg 8
                    2.996158913811580e-01,  # deg 12
                    1.090863719290036e+00   # deg 18
                ]

            # generate norms to test different degree expansions
            sample_norms = []
            for i in range(len(thetas) - 1):
                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
            degs = [degs[0]] + degs

            for sample_norm, deg in zip(sample_norms, degs):
                x = gen_good_cond_number_matrices(*n)
                x = normalize_to_1_operator_norm(x, sample_norm)

                mexp = torch.linalg.matrix_exp(x)
                mexp_taylor = scale_square(x, deg)

                self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0)

        # single matrix
        run_test(2, 2)
        run_test(3, 3)
        run_test(4, 4)
        run_test(5, 5)

        # small batch of matrices
        run_test(3, 2, 2)
        run_test(3, 3, 3)
        run_test(3, 4, 4)
        run_test(3, 5, 5)

        # large batch of matrices
        run_test(3, 3, 2, 2)
        run_test(3, 3, 3, 3)
        run_test(3, 3, 4, 4)
        run_test(3, 3, 5, 5)
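
    # The scale_square helper above relies on the identity
    #     exp(A) == exp(A / 2**s) ** (2**s),
    # i.e. shrink the matrix until a fixed-degree Taylor series is accurate, then
    # square the result s times. A minimal self-contained sketch of that identity
    # (illustrative only; the helper name, sizes and tolerance are assumptions):
    def _scaling_and_squaring_identity_holds(self, device='cpu', dtype=torch.double):
        a = torch.randn(4, 4, dtype=dtype, device=device)
        s = 3
        b = torch.linalg.matrix_exp(a / (2 ** s))
        for _ in range(s):
            b = b @ b
        return torch.allclose(b, torch.linalg.matrix_exp(a), atol=1e-6)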

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_slogdet(self, device, dtype):
        from torch.testing._internal.common_utils import (random_hermitian_matrix, random_hermitian_psd_matrix,
                                                          random_hermitian_pd_matrix, random_square_matrix_of_rank)

        # mat_chars denotes matrix characteristics
        # possible values are: hermitian, hermitian_psd, hermitian_pd, singular, non_singular
        def run_test(matsize, batchdims, mat_chars):
            num_matrices = np.prod(batchdims)
            list_of_matrices = []
            if num_matrices != 0:
                for idx in range(num_matrices):
                    mat_type = idx % len(mat_chars)
                    if mat_chars[mat_type] == 'hermitian':
                        list_of_matrices.append(random_hermitian_matrix(matsize, dtype=dtype, device=device))
                    elif mat_chars[mat_type] == 'hermitian_psd':
                        list_of_matrices.append(random_hermitian_psd_matrix(matsize, dtype=dtype, device=device))
                    elif mat_chars[mat_type] == 'hermitian_pd':
                        list_of_matrices.append(random_hermitian_pd_matrix(matsize, dtype=dtype, device=device))
                    elif mat_chars[mat_type] == 'singular':
                        list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
                    elif mat_chars[mat_type] == 'non_singular':
                        list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
                full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
            else:
                full_tensor = torch.randn(*batchdims, matsize, matsize, dtype=dtype, device=device)

            actual_value = torch.linalg.slogdet(full_tensor)
            expected_value = np.linalg.slogdet(full_tensor.cpu().numpy())
            self.assertEqual(expected_value[0], actual_value[0], atol=self.precision, rtol=self.precision)
            self.assertEqual(expected_value[1], actual_value[1], atol=self.precision, rtol=self.precision)

            # test out=variant
            sign_out = torch.empty_like(actual_value[0])
            logabsdet_out = torch.empty_like(actual_value[1])
            ans = torch.linalg.slogdet(full_tensor, out=(sign_out, logabsdet_out))
            self.assertEqual(ans[0], sign_out)
            self.assertEqual(ans[1], logabsdet_out)
            self.assertEqual(sign_out, actual_value[0])
            self.assertEqual(logabsdet_out, actual_value[1])

        for matsize, batchdims in itertools.product([0, 3, 5], [(0,), (3,), (5, 3)]):
            run_test(matsize, batchdims, mat_chars=['hermitian_pd'])
            run_test(matsize, batchdims, mat_chars=['singular'])
            run_test(matsize, batchdims, mat_chars=['non_singular'])
            run_test(matsize, batchdims, mat_chars=['hermitian', 'hermitian_pd', 'hermitian_psd'])
            run_test(matsize, batchdims, mat_chars=['singular', 'non_singular'])

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_slogdet_errors_and_warnings(self, device, dtype):
        # slogdet requires the input to be a square matrix or batch of square matrices
        a = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
            torch.linalg.slogdet(a)

        # slogdet requires the input to be at least 2 dimensional tensor
        a = torch.randn(2, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
            torch.linalg.slogdet(a)

        a = torch.randn(2, 2, device=device, dtype=torch.bfloat16)
        with self.assertRaisesRegex(RuntimeError, r'Low precision dtypes not supported'):
            torch.linalg.slogdet(a)

        # if non-empty out tensor with wrong shape is passed a warning is given
        a = torch.randn(2, 3, 3, device=device, dtype=dtype)
        sign_out = torch.empty(1, device=device, dtype=dtype)
        real_dtype = a.real.dtype if dtype.is_complex else dtype
        logabsdet_out = torch.empty(1, device=device, dtype=real_dtype)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            sign_out = torch.empty(0, device=wrong_device, dtype=dtype)
            logabsdet_out = torch.empty(0, device=wrong_device, dtype=real_dtype)
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
                torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
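
    # test_det_logdet_slogdet below repeatedly uses the relation
    #     det(A) == sign * exp(logabsdet),   where (sign, logabsdet) = slogdet(A),
    # which is also how slogdet results can be recombined without overflowing for
    # large matrices. A minimal sketch of that recombination (illustrative helper,
    # not used by the tests; the name is an arbitrary choice):
    def _det_from_slogdet(self, A):
        sign, logabsdet = torch.linalg.slogdet(A)
        return sign * logabsdet.exp()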

    # FIXME One of the backends of lu_factor fails on Windows. I haven't investigated which or why
    # https://github.com/pytorch/pytorch/issues/75225
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @skipCUDAIfNoCusolver
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_det_logdet_slogdet(self, device, dtype):
        def reference_slogdet(M):
            sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy())
            return M.new_tensor(sdet), M.new_tensor(logabsdet)

        def test_single_det(M, target, desc):
            target_sdet, target_logabsdet = target

            det = M.det()
            logdet = M.logdet()
            sdet, logabsdet = M.slogdet()
            linalg_sdet, linalg_logabsdet = torch.linalg.slogdet(M)

            # Test det
            self.assertEqual(det, target_sdet * target_logabsdet.exp(),
                             atol=1e-6, rtol=0, msg=f'{desc} (det)')

            # Test slogdet
            # Compare the overall value rather than individual parts because of
            # precision issues when det is near zero.
            self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(),
                             atol=1e-6, rtol=0, msg=f'{desc} (slogdet)')
            self.assertEqual(linalg_sdet * linalg_logabsdet.exp(), target_sdet * target_logabsdet.exp(),
                             atol=1e-6, rtol=0, msg=f'{desc} (linalg_slogdet)')

            # Test logdet
            # Compare logdet against our own pytorch slogdet because they should
            # be consistent, while it may behave slightly differently with other
            # slogdet implementations when det is near zero due to precision
            # issues.
            if sdet.item() < 0:
                self.assertTrue(logdet.item() != logdet.item(), f'{desc} (logdet negative case)')
            else:
                self.assertEqual(logdet.exp(), target_logabsdet.exp(),
                                 atol=1e-6, rtol=0, msg=f'{desc} (logdet non-negative case)')

        eye = torch.eye(5, dtype=dtype, device=device)
        test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity')
        # Testing bug in #34061 (https://github.com/pytorch/pytorch/issues/34061)
        for n in range(250, 551, 100):
            mat = torch.randn(n, n, dtype=dtype, device=device)
            q, _ = torch.qr(mat)
            ref_det, ref_logabsdet = reference_slogdet(q)
            test_single_det(q, (ref_det, ref_logabsdet), 'orthogonal')

        def test(M):
            assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
            M = M.to(device)

            ref_M_sdet, ref_M_logabsdet = reference_slogdet(M)

            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic')
            if ref_M_logabsdet.exp().item() >= 1e-6:  # skip singular
                M_inv = M.inverse()
                test_single_det(M_inv, reference_slogdet(M_inv), 'inverse')

            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose')

            for x in [0, 2, 4]:
                for scale in [-2, -0.1, 0, 10]:
                    if scale > 0:
                        target = ref_M_sdet, ref_M_logabsdet + math.log(scale)
                    elif scale == 0:
                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
                    else:
                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale)

                    # dim 0
                    M_clone = M.clone()
                    M_clone[:, x] *= scale
                    test_single_det(M_clone, target, 'scale a row')
                    # dim 1
                    M_clone = M.clone()
                    M_clone[x, :] *= scale
                    test_single_det(M_clone, target, 'scale a column')

            for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
                assert x1 != x2, 'x1 and x2 need to be different for this test'
                target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
                # dim 0
                M_clone = M.clone()
                M_clone[:, x2] = M_clone[:, x1]
                test_single_det(M_clone, target, 'two rows are same')
                # dim 1
                M_clone = M.clone()
                M_clone[x2, :] = M_clone[x1, :]
                test_single_det(M_clone, target, 'two columns are same')

                for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
                    det_scale = scale1 * scale2 * -1
                    if det_scale > 0:
                        target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale)
                    elif det_scale == 0:
                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
                    else:
                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale)

                    # dim 0
                    M_clone = M.clone()
                    t = M_clone[:, x1] * scale1
                    M_clone[:, x1] += M_clone[:, x2] * scale2
                    M_clone[:, x2] = t
                    test_single_det(M_clone, target, 'exchanging rows')
                    # dim 1
                    M_clone = M.clone()
                    t = M_clone[x1, :] * scale1
                    M_clone[x1, :] += M_clone[x2, :] * scale2
                    M_clone[x2, :] = t
                    test_single_det(M_clone, target, 'exchanging columns')

        def get_random_mat_scale(n):
            # For matrices with values i.i.d. with 0 mean, unit variance, and
            # subexponential tail, we have:
            #   E[log det(A^2)] \approx log((n-1)!)
            #
            # Notice:
            #   log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)]
            #
            # So:
            #   stddev[det(A)] >= sqrt( (n-1)! )
            #
            # We use this as an intuitive guideline to scale randomly generated
            # matrices so our closeness tests can work more robustly:
            #   scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n))
            #
            # source: https://arxiv.org/pdf/1112.0752.pdf

            # TODO: technically we need subexponential distn for this to hold,
            #       but we mostly use gaussian entries below. Consider switching
            #       to Chi-sq if this turns out not stable enough, since Chi-sq
            #       is easy enough to sample from.
            return math.factorial(n - 1) ** (-1.0 / (2 * n))

        for n in [5, 10, 25]:
            scale = get_random_mat_scale(n)
            test(torch.randn(n, n, dtype=dtype, device=device) * scale)
            r = torch.randn(n, n, dtype=dtype, device=device) * scale
            # symmetric psd
            test(r.mm(r.t()))
            # symmetric pd
            r = torch.randn(n, n, dtype=dtype, device=device) * scale
            test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6)
            # symmetric
            r = torch.randn(n, n, dtype=dtype, device=device) * scale
            for i in range(n):
                for j in range(i):
                    r[i, j] = r[j, i]
            test(r)
            # non-contiguous
            test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:])
            # det = 0
            r = torch.randn(n, n, dtype=dtype, device=device) * scale
            u, s, v = r.svd()
            if reference_slogdet(u)[0] < 0:
                u = -u
            if reference_slogdet(v)[0] < 0:
                v = -v
            s[0] *= -1
            s[-1] = 0
            test(u.mm(s.diag()).mm(v))

        # Small values to test numerical stability. Note that we don't scale
        # this matrix.
        r = torch.randn(512, 512, dtype=dtype, device=device)
        u, s, v = r.svd()
        s.fill_(1. / (100 * s.numel()))
        test(u.mm(s.diag()).mm(v))
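
    # The get_random_mat_scale heuristic above in concrete numbers (an illustrative
    # aside, not used by the tests; the helper name and the default n are arbitrary):
    # for n = 5 the suggested factor is (4!) ** (-1/10) ~= 0.728, and it decreases as
    # n grows, which keeps det of the rescaled matrix roughly O(1) for the closeness
    # checks above.
    def _random_mat_scale_example(self, n=5):
        return math.factorial(n - 1) ** (-1.0 / (2 * n))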

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_det_logdet_slogdet_batched(self, device, dtype):
        from torch.testing._internal.common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix,
                                                          random_symmetric_pd_matrix, random_square_matrix_of_rank)

        # mat_chars denotes matrix characteristics
        # possible values are: sym, sym_psd, sym_pd, sing, non_sing
        def run_test(matsize, batchdims, mat_chars):
            num_matrices = reduce(operator.mul, batchdims, 1)
            list_of_matrices = []

            for idx in range(num_matrices):
                mat_type = idx % len(mat_chars)
                if mat_chars[mat_type] == 'sym':
                    list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device))
                elif mat_chars[mat_type] == 'sym_psd':
                    list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device))
                elif mat_chars[mat_type] == 'sym_pd':
                    list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device))
                elif mat_chars[mat_type] == 'sing':
                    list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
                elif mat_chars[mat_type] == 'non_sing':
                    list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
            full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
            # Scaling adapted from `get_random_mat_scale` in test_det_logdet_slogdet
            full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize)))

            for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]:
                expected_value = []
                actual_value = fn(full_tensor)
                for full_idx in itertools.product(*(list(range(x)) for x in batchdims)):
                    expected_value.append(fn(full_tensor[full_idx]))

                if fn == torch.slogdet or fn == torch.linalg.slogdet:
                    sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims)
                    expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims)
                    self.assertEqual(sign_value, actual_value[0])
                    self.assertEqual(expected_value, actual_value[1])
                else:
                    expected_value = torch.stack(expected_value, dim=0).reshape(batchdims)
                    self.assertEqual(actual_value, expected_value)

        for matsize, batchdims in itertools.product([3, 5], [(3,), (5, 3)]):
            run_test(matsize, batchdims, mat_chars=['sym_pd'])
            run_test(matsize, batchdims, mat_chars=['sing'])
            run_test(matsize, batchdims, mat_chars=['non_sing'])
            run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd'])
            run_test(matsize, batchdims, mat_chars=['sing', 'non_sing'])

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky_inverse(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(shape, batch, upper, contiguous):
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
            if A.numel() > 0 and not contiguous:
                A = A.mT
                self.assertFalse(A.is_contiguous())
            L = torch.linalg.cholesky(A)
            expected_inverse = torch.inverse(A)
            L = L.mH if upper else L
            actual_inverse = torch.cholesky_inverse(L, upper)
            self.assertEqual(actual_inverse, expected_inverse)

        shapes = (0, 3, 5)
        batches = ((), (0,), (3, ), (2, 2))
        for shape, batch, upper, contiguous in list(itertools.product(shapes, batches, (True, False), (True, False))):
            run_test(shape, batch, upper, contiguous)

        # check the out= variant
        A = random_hermitian_pd_matrix(3, 2, dtype=dtype, device=device)
        L = torch.linalg.cholesky(A)

        # There are two code paths currently for the out= variant
        # 1. When 'out' tensor is in Fortran (column-major) memory format
        #    then the fast route is taken and the storage is reused directly in the computations
        # 2. When 'out' tensor is not in Fortran format then a temporary tensor is allocated internally
        #    and the result is copied from the temporary tensor to 'out' tensor

        # This test checks the first code path
        out = torch.empty_like(A)
        out_t = out.mT.clone(memory_format=torch.contiguous_format)
        out = out_t.mT
        ans = torch.cholesky_inverse(L, out=out)
        self.assertEqual(ans, out)
        expected = torch.inverse(A)
        self.assertEqual(expected, out)

        # This test checks the second code path
        out = torch.empty_like(A)
        ans = torch.cholesky_inverse(L, out=out)
        self.assertEqual(ans, out)
        expected = torch.inverse(A)
        self.assertEqual(expected, out)
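
    # The "first code path" exercised in test_cholesky_inverse above requires 'out'
    # to be in Fortran (column-major) order. A minimal sketch of how such a tensor
    # can be built and recognized (illustrative only; the helper name is an
    # arbitrary assumption, and the stride layout is the point here):
    def _make_fortran_ordered_like(self, A):
        out = torch.empty_like(A).mT.clone(memory_format=torch.contiguous_format).mT
        # For a column-major matrix the stride along the row dimension is 1.
        assert out.stride()[-2] == 1
        return out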

    def _select_broadcastable_dims(self, dims_full=None):
        # select full dimensionality
        if dims_full is None:
            dims_full = []
            ndims = random.randint(1, 4)
            dims_full = [random.randint(1, 8) for _ in range(ndims)]
        else:
            ndims = len(dims_full)

        # select actual dimensions for ops:
        # larger: full ndims, individual sizes may be reduced
        # smaller: possibly reduced ndims, sizes may be reduced
        smaller_ndims = random.randint(1, ndims)
        dims_small = []
        dims_large = []
        for i in range(ndims - 1, -1, -1):
            j = random.randint(1, 3)
            if j == 1:  # no reduced singleton dimension
                ds = dims_full[i]
                dl = dims_full[i]
            elif j == 2:  # larger may have reduced singleton dimension
                ds = dims_full[i]
                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
            elif j == 3:  # smaller may have reduced singleton dimension
                ds = 1
                dl = dims_full[i]
            dims_large = [dl] + dims_large
            if len(dims_small) < smaller_ndims:
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)
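
    # Editor's note: a hypothetical worked example of what the helper above may return, to
    # make the broadcasting vocabulary concrete. For dims_full = [2, 5, 3] one possible draw
    # is dims_large = [2, 1, 3] and dims_small = [5, 1]: the "large" shape keeps the full
    # rank but may turn entries into singleton dimensions, while the "small" shape may also
    # drop leading dimensions. Both still broadcast against dims_full.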

    def test_broadcast_fused_matmul(self, device):
        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

        for fn in fns:
            batch_dim = random.randint(1, 8)
            n_dim = random.randint(1, 8)
            m_dim = random.randint(1, 8)
            p_dim = random.randint(1, 8)

            def dims_full_for_fn():
                if fn == "baddbmm":
                    return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
                elif fn == "addbmm":
                    return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
                elif fn == "addmm":
                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
                elif fn == "addmv":
                    return ([n_dim], [n_dim, m_dim], [m_dim])
                elif fn == "addr":
                    return ([n_dim, m_dim], [n_dim], [m_dim])
                else:
                    raise AssertionError("unknown function")

            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)

            t0_small = torch.randn(*t0_dims_small, device=device).float()
            t1 = torch.randn(*t1_dims, device=device).float()
            t2 = torch.randn(*t2_dims, device=device).float()

            t0_full = t0_small.expand(*t0_dims_full).to(device)

            fntorch = getattr(torch, fn)
            r0 = fntorch(t0_small, t1, t2)
            r1 = fntorch(t0_full, t1, t2)
            self.assertEqual(r0, r1)
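
    # Editor's note: illustrative sketch (not an upstream test). It shows the property the
    # test above relies on: the additive input of the fused ops broadcasts, so passing a
    # smaller bias gives the same result as passing its fully expanded counterpart.
    def _example_addmm_bias_broadcasting(self):
        bias_small = torch.randn(1, 4)        # broadcastable along the rows
        bias_full = bias_small.expand(3, 4)   # explicitly expanded bias
        mat1 = torch.randn(3, 5)
        mat2 = torch.randn(5, 4)
        r_small = torch.addmm(bias_small, mat1, mat2)
        r_full = torch.addmm(bias_full, mat1, mat2)
        return torch.allclose(r_small, r_full)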

    @tf32_on_and_off(0.001)
    @bf32_on_and_off(0.001)
    def test_broadcast_batched_matmul(self, device):
        n_dim = random.randint(1, 8)
        m_dim = random.randint(1, 8)
        p_dim = random.randint(1, 8)
        full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))]
        (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)

        def verify_batched_matmul(full_lhs, one_dimensional):
            if not one_dimensional:
                lhs_dims = [n_dim, m_dim]
                rhs_dims = [m_dim, p_dim]
                result_dims = [n_dim, p_dim]
            else:
                lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim]
                rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim]
                result_dims = [n_dim] if full_lhs else [p_dim]

            lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim]
            rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1]
            full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims
            dim0_dims = rhs_dims if full_lhs else lhs_dims
            small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims)

            small = torch.randn(*(small_dims), device=device).float()
            dim0 = torch.randn(*(dim0_dims), device=device).float()
            full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float()
            if not one_dimensional:
                (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
            else:
                (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,))

            def maybe_squeeze_result(l, r, result):
                if len(lhs_dims) == 1 and l.dim() != 1:
                    return result.squeeze(-2)
                elif len(rhs_dims) == 1 and r.dim() != 1:
                    return result.squeeze(-1)
                else:
                    return result

            for lhs in lhsTensors:
                lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
                lhs_expanded_matmul_fn = lhs_expanded.matmul
                for rhs in rhsTensors:
                    rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)).
                                    expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
                    truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
                    for l in (lhs, lhs_expanded):
                        for r in (rhs, rhs_expanded):
                            l_matmul_fn = l.matmul
                            result = maybe_squeeze_result(l, r, l_matmul_fn(r))
                            self.assertEqual(truth, result)
                            # test torch.matmul function as well
                            torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
                            self.assertEqual(truth, torch_result)
                            # test torch.matmul with out
                            out = torch.zeros_like(torch_result)
                            torch.matmul(l, r, out=out)
                            self.assertEqual(truth, maybe_squeeze_result(l, r, out))

                    # compare to bmm
                    bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
                                            rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
                    self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))

        for indices in itertools.product((True, False), repeat=2):
            verify_batched_matmul(*indices)
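
    # Editor's note: illustrative sketch (not an upstream test) of the torch.matmul
    # broadcasting rules exercised above: batch dimensions broadcast, a 1-D right operand is
    # treated as a column vector, and the dimension inserted for it is removed from the result.
    def _example_matmul_broadcasting(self):
        batched = torch.randn(2, 3, 4, 5)  # batch dims (2, 3), matrices of shape (4, 5)
        mat = torch.randn(5, 6)            # broadcast against every batch entry -> (2, 3, 4, 6)
        vec = torch.randn(5)               # treated as (5, 1), result squeezed -> (2, 3, 4)
        return (batched @ mat).shape, torch.matmul(batched, vec).shape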

    def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_A = partial(make_fullrank, device=device, dtype=dtype)

        b = torch.randn(*b_dims, dtype=dtype, device=device)
        A = make_A(*A_dims)
        LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
        self.assertEqual(info, torch.zeros_like(info))
        return b, A, LU_data, LU_pivots

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_lu_solve(self, device, dtype):
        def sub_test(pivot):
            for k, n in zip([2, 3, 5], [3, 5, 7]):
                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n, n), (n, k), pivot, device, dtype)
                x = torch.lu_solve(b, LU_data, LU_pivots)
                self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))

        sub_test(True)
        if self.device_type == 'cuda':
            sub_test(False)
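
    # Editor's note: illustrative sketch (not an upstream test) of the factor-then-solve
    # pattern used by the helper and tests above: factor A once with torch.linalg.lu_factor
    # and reuse the factorization for one or more right-hand sides.
    def _example_lu_factor_and_solve(self):
        A = torch.randn(4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
        b = torch.randn(4, 2, dtype=torch.float64)
        LU, pivots = torch.linalg.lu_factor(A)
        x = torch.linalg.lu_solve(LU, pivots, b)  # solves A @ x == b
        return torch.allclose(A @ x, b)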

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_lu_solve_batched(self, device, dtype):
        def sub_test(pivot):
            def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
                b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
                x_exp_list = []
                for i in range(b_dims[0]):
                    x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
                x_exp = torch.stack(x_exp_list)  # Stacked output
                x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
                self.assertEqual(x_exp, x_act)  # Equality check
                Ax = np.matmul(A.cpu(), x_act.cpu())
                self.assertEqual(b, Ax)

            for batchsize in [1, 3, 4]:
                lu_solve_batch_test_helper((batchsize, 5, 5), (batchsize, 5, 10), pivot)

        # Tests tensors with 0 elements
        b = torch.randn(3, 0, 3, dtype=dtype, device=device)
        A = torch.randn(3, 0, 0, dtype=dtype, device=device)
        LU_data, LU_pivots = torch.linalg.lu_factor(A)
        self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))

        sub_test(True)
        if self.device_type == 'cuda':
            sub_test(False)

    @slowTest
    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(*floating_and_complex_types())
    def test_lu_solve_batched_many_batches(self, device, dtype):
        def run_test(A_dims, b_dims):
            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
            x = torch.lu_solve(b, LU_data, LU_pivots)
            Ax = torch.matmul(A, x)
            self.assertEqual(Ax, b.expand_as(Ax))

        run_test((65536, 5, 5), (65536, 5, 10))
        run_test((262144, 5, 5), (262144, 5, 10))

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(*floating_and_complex_types())
    def test_lu_solve_batched_broadcasting(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_A = partial(make_fullrank, device=device, dtype=dtype)

        def run_test(A_dims, b_dims, pivot=True):
            A_matrix_size = A_dims[-1]
            A_batch_dims = A_dims[:-2]
            A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
            b = make_tensor(b_dims, dtype=dtype, device=device)
            x_exp = np.linalg.solve(A.cpu(), b.cpu())
            LU_data, LU_pivots = torch.linalg.lu_factor(A)
            x = torch.lu_solve(b, LU_data, LU_pivots)
            self.assertEqual(x, x_exp)

        # test against numpy.linalg.solve
        run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))  # no broadcasting
        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting b
        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & b

    @onlyCUDA
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    # this tests https://github.com/pytorch/pytorch/issues/36921
    def test_lu_solve_large_matrices(self, device, dtype):
        def run_test(A_dims, b_dims):
            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
            x = torch.lu_solve(b, LU_data, LU_pivots)
            Ax = torch.matmul(A, x)
            self.assertEqual(Ax, b.expand_as(Ax))

        run_test((1, 1), (1, 1, 1025))

    @skipCUDAIfNoCusolver
    @skipCPUIfNoLapack
    def test_pca_lowrank(self, device):
        from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix

        dtype = torch.double

        def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **options):
            density = options.pop('density', 1)
            use_svd_lowrank = options.pop('use_svd_lowrank', False)
            if isinstance(matrix_size, int):
                rows = columns = matrix_size
            else:
                rows, columns = matrix_size
            if density == 1:
                a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
                a = a_input
            else:
                a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
                a = a_input.to_dense()

            if use_svd_lowrank:
                m = a_input.mean(dim=-2, keepdim=True)
                u, s, v = pca(a_input, q=guess_rank, M=m, **options)
            else:
                u, s, v = pca(a_input, q=guess_rank, **options)

            self.assertEqual(s.shape[-1], guess_rank)
            self.assertEqual(u.shape[-2], rows)
            self.assertEqual(u.shape[-1], guess_rank)
            self.assertEqual(v.shape[-1], guess_rank)
            self.assertEqual(v.shape[-2], columns)

            A1 = u.matmul(s.diag_embed()).matmul(v.mT)
            ones_m1 = torch.ones(batches + (rows, 1), dtype=a.dtype, device=device)
            c = a.sum(axis=-2) / rows
            c = c.reshape(batches + (1, columns))
            A2 = a - ones_m1.matmul(c)
            self.assertEqual(A1, A2)

            if density == 1:
                # actual rank is known only for dense input
                detect_rank = (s.abs() > 1e-5).sum(axis=-1)
                self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank)
                S = torch.linalg.svdvals(A2)
                self.assertEqual(s[..., :actual_rank], S[..., :actual_rank])

        all_batches = [(), (1,), (3,), (2, 3)]
        for actual_rank, size, all_batches in [  # noqa: B020
                (2, (17, 4), all_batches),
                (2, (100, 4), all_batches),
                (6, (100, 40), all_batches),
                (12, (1000, 1000), [()]),
        ]:
            for batches in all_batches:
                for guess_rank in [
                        actual_rank,
                        actual_rank + 2,
                        actual_rank + 6,
                ]:
                    if guess_rank <= min(*size):
                        run_subtest(guess_rank, actual_rank, size, batches, device, torch.pca_lowrank)
                        run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.pca_lowrank)
                        run_subtest(guess_rank, actual_rank, size, batches, device, torch.svd_lowrank, use_svd_lowrank=True)
                        run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.svd_lowrank, use_svd_lowrank=True)

        # sparse input
        for guess_rank, size in [
                (4, (17, 4)), (4, (4, 17)), (16, (17, 17)),
                (21, (100, 40)), (20, (40, 100)), (600, (1000, 1000))]:
            for density in [0.005, 0.1]:
                run_subtest(guess_rank, None, size, (), device, torch.pca_lowrank, density=density)

        # jitting support
        jitted = torch.jit.script(torch.pca_lowrank)
        guess_rank, actual_rank, size, batches = 2, 2, (17, 4), ()
        run_subtest(guess_rank, actual_rank, size, batches, device, jitted)
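
    # Editor's note: illustrative sketch (not an upstream test) of the relation checked
    # above: torch.pca_lowrank returns (U, S, V) such that U @ diag(S) @ V^T approximates
    # the column-centered input. The helper returns the worst-case reconstruction error.
    def _example_pca_lowrank_usage(self):
        # a has rank at most 8 by construction
        a = torch.randn(100, 8, dtype=torch.float64) @ torch.randn(8, 20, dtype=torch.float64)
        u, s, v = torch.pca_lowrank(a, q=8)
        centered = a - a.mean(dim=-2, keepdim=True)
        approx = u @ torch.diag(s) @ v.mT
        return (approx - centered).abs().max()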

    # Ensure that nuclear_norm's out= variant gives the same result as the non-out variant
    @onlyNativeDeviceTypes
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float32, torch.float64)
    def test_nuclear_norm_out(self, device, dtype):
        test_cases = [
            # input size, dim
            ((25, 25), None),
            ((25, 25), (0, 1)),
            ((25, 25), (1, 0)),
            ((25, 25, 25), (2, 0)),
            ((25, 25, 25), (0, 1)),
        ]
        for keepdim in [False, True]:
            for input_size, dim in test_cases:
                msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}'
                x = torch.randn(*input_size, device=device, dtype=dtype)
                result_out = torch.empty(0, device=device, dtype=dtype)
                if dim is None:
                    result = torch.nuclear_norm(x, keepdim=keepdim)
                    torch.nuclear_norm(x, keepdim=keepdim, out=result_out)
                else:
                    result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim)
                    torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out)
                self.assertEqual(result, result_out, msg=msg)
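
    # Editor's note: illustrative sketch (not an upstream test): the nuclear norm computed
    # above is the sum of the singular values, so it can be cross-checked with svdvals.
    def _example_nuclear_norm_is_sum_of_singular_values(self):
        x = torch.randn(5, 7, dtype=torch.float64)
        return torch.allclose(torch.linalg.matrix_norm(x, 'nuc'), torch.linalg.svdvals(x).sum())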

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_geqrf(self, device, dtype):

        def run_test(shape):
            # numpy.linalg.qr with mode = 'raw' computes the same operation as torch.geqrf
            # so this test compares against that function
            A = make_tensor(shape, dtype=dtype, device=device)

            # numpy.linalg.qr doesn't work with batched input
            m, n = A.shape[-2:]
            tau_size = "n" if m > n else "m"
            np_dtype = A.cpu().numpy().dtype
            ot = [np_dtype, np_dtype]
            numpy_geqrf_batched = np.vectorize(
                lambda x: np.linalg.qr(x, mode='raw'),
                otypes=ot,
                signature=f'(m,n)->(n,m),({tau_size})')

            expected = numpy_geqrf_batched(A.cpu())
            actual = torch.geqrf(A)

            # numpy.linalg.qr returns transposed result
            self.assertEqual(expected[0].swapaxes(-2, -1), actual[0])
            self.assertEqual(expected[1], actual[1])

        batches = [(), (0, ), (2, ), (2, 1)]
        ns = [5, 2, 0]
        for batch, (m, n) in product(batches, product(ns, ns)):
            run_test((*batch, m, n))
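
    # Editor's note: illustrative sketch (not an upstream test) of how the raw geqrf output
    # compared against NumPy above is normally consumed: the upper triangle holds R, and the
    # Householder reflectors stored below the diagonal, together with tau, reconstruct Q.
    def _example_geqrf_reconstruct_qr(self):
        A = torch.randn(6, 4, dtype=torch.float64)
        a, tau = torch.geqrf(A)
        Q = torch.linalg.householder_product(a, tau)  # reduced Q, shape (6, 4)
        R = a[:A.shape[-1]].triu()                     # reduced R, shape (4, 4)
        return torch.allclose(Q @ R, A)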

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    def test_lapack_empty(self, device):
        # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here.
        # The LAPACK functions themselves generally do NOT work with zero sized dimensions, although
        # numpy/scipy often have a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing"
        # (e.g. lu). We often name our functions identically to the lapack function, so it will take work
        # to name / migrate-to better wrappers.
        def fn(torchfn, *args):
            return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
                                  for shape in args))

        # inverse, pinverse
        self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape)
        self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape)
        self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape)
        self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape)

        # det, logdet, slogdet
        self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0)))
        self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0)))
        self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)),
                         fn(torch.slogdet, (0, 0)))

    @tf32_on_and_off(0.005)
    @bf32_on_and_off(0.005)
    def test_tensordot(self, device):
        a = torch.arange(60., device=device).reshape(3, 4, 5)
        b = torch.arange(24., device=device).reshape(4, 3, 2)
        c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
                                           axes=([1, 0], [0, 1])))
        self.assertEqual(c, cn)

        cout = torch.zeros((5, 2), device=device)
        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
        self.assertEqual(c, cout)

        a = torch.randn(2, 3, 4, 5, device=device)
        b = torch.randn(4, 5, 6, 7, device=device)
        c = torch.tensordot(a, b, dims=2).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
                                           axes=2))

        with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
            torch.tensordot(a, b, dims=-1)

        self.assertEqual(c, cn)
        c = torch.tensordot(a, b).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(c, cn)

        a = torch.tensordot(torch.tensor(0.), torch.tensor(0.), 0)
        an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
        self.assertEqual(a, an)
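
    # Editor's note: illustrative sketch (not an upstream test): tensordot with an integer
    # `dims` contracts that many trailing dimensions of the first operand with the leading
    # dimensions of the second, which is equivalent to flattening both operands and doing a
    # single matrix multiplication.
    def _example_tensordot_as_matmul(self):
        a = torch.randn(2, 3, 4, 5)
        b = torch.randn(4, 5, 6, 7)
        c = torch.tensordot(a, b, dims=2)  # contracts the shared (4, 5) dimensions
        c_ref = (a.reshape(6, 20) @ b.reshape(20, 42)).reshape(2, 3, 6, 7)
        return torch.allclose(c, c_ref, atol=1e-5)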

    @skipCUDAIfNoCusolver
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipIfTorchDynamo("flaky, needs investigation")
    @dtypes(*floating_and_complex_types())
    def test_ldl_factor(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(shape, batch, hermitian):
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
            actual_factors, actual_pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
            actual_L = torch.tril(actual_factors, diagonal=-1)
            actual_L.diagonal(0, -2, -1).fill_(1.0)

            # This test is designed only for inputs with a 1x1 block diagonal matrix D.
            # That is, for positive definite input matrices the pivots tensor is always > 0.
            # If negative pivots are encountered, the input matrix is not positive definite
            # and D is a 2x2 block diagonal matrix.
            self.assertTrue((actual_pivots > 0).all())

            # Construct a 1x1 block diagonal matrix D from factors.
            actual_D = torch.diag_embed(actual_factors.diagonal(0, -2, -1))

            def T(x):
                return x.mH if hermitian else x.mT
            A_reconstructed = actual_L @ actual_D @ T(actual_L)

            def symmetric(A):
                return A.tril() + A.tril(-1).mT

            self.assertEqual(symmetric(A) if not hermitian else A, A_reconstructed)

            # Now test against SciPy implementation
            if TEST_SCIPY:
                from scipy.linalg import ldl as scipy_ldl
                A_np = A.cpu().numpy()
                np_dtype = A_np.dtype
                scipy_ldl_batched = np.vectorize(
                    lambda x: scipy_ldl(x, hermitian=hermitian, lower=True),
                    otypes=[np_dtype, np_dtype, np.dtype('int64')],
                    signature='(m,m)->(m,m),(m,m),(m)')

                expected = scipy_ldl_batched(A_np)
                expected_L, expected_D, expected_pivots = expected

                if expected_pivots.ndim > 1:
                    permuted_expected_L = np.stack(
                        [expected_L[i][expected_pivots[i], :] for i in range(expected_pivots.shape[0])]
                    )
                else:
                    permuted_expected_L = expected_L[expected_pivots, :]
                self.assertEqual(actual_L, permuted_expected_L)
                self.assertEqual(actual_D, expected_D)
            else:
                self.assertEqual(actual_factors.shape, A.shape)
                self.assertEqual(actual_pivots.shape, A.shape[:-1])
                self.assertEqual(info.shape, A.shape[:-2])

        # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
        magma_254_available = self.device_type == 'cuda' and _get_magma_version() >= (2, 5, 4)
        hermitians = (True, False) if dtype.is_complex and (self.device_type == 'cpu' or magma_254_available) else (False,)

        shapes = (5,)
        batches = ((), (4,),)
        for shape, batch, hermitian in itertools.product(shapes, batches, hermitians):
            run_test(shape, batch, hermitian)
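
    # Editor's note: illustrative sketch (not an upstream test) of the factor/solve pairing
    # validated above and in the next test: ldl_factor_ex produces the packed factors and
    # pivots that ldl_solve consumes to solve A @ X == B.
    def _example_ldl_factor_and_solve(self):
        A = torch.randn(5, 5, dtype=torch.float64)
        A = A @ A.mT + 5 * torch.eye(5, dtype=torch.float64)  # symmetric positive definite
        B = torch.randn(5, 3, dtype=torch.float64)
        factors, pivots, info = torch.linalg.ldl_factor_ex(A)
        X = torch.linalg.ldl_solve(factors, pivots, B)
        return torch.allclose(A @ X, B)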

    @skipCUDAIfNoCusolver
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipCUDAIfRocm
    @skipCUDAIf(_get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1")
    @dtypes(*floating_and_complex_types())
    def test_ldl_solve(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(shape, batch, nrhs, hermitian):
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
            B = make_tensor((*A.shape[:-1], nrhs), dtype=dtype, device=device)
            factors, pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
            X = torch.linalg.ldl_solve(factors, pivots, B, hermitian=hermitian)

            def symmetric(A):
                return A.tril() + A.tril(-1).mT

            # verify A @ X == B
            expected_B = symmetric(A) @ X if not hermitian else A @ X
            self.assertEqual(B, expected_B)

        # hermitian=True is not supported on CUDA yet
        hermitians = (True, False) if dtype.is_complex and self.device_type == 'cpu' else (False,)

        shapes = (5,)
        batches = ((), (4,), (2, 2))
        nrhss = (1, 7)
        for shape, batch, nrhs, hermitian in itertools.product(shapes, batches, nrhss, hermitians):
            run_test(shape, batch, nrhs, hermitian)

    @onlyCUDA
    @skipCUDAIfNoMagma
    @skipCUDAIfNoCusolver
    @setLinalgBackendsToDefaultFinally
    def test_preferred_linalg_library(self):
        # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions.
        x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double)

        torch.backends.cuda.preferred_linalg_library('cusolver')
        out1 = torch.linalg.inv(x)

        torch.backends.cuda.preferred_linalg_library('magma')
        out2 = torch.linalg.inv(x)

        torch.backends.cuda.preferred_linalg_library('default')
        # Although the preferred linalg flags don't affect CPU currently,
        # we set this to make sure the flag can switch back to default normally.
        out_ref = torch.linalg.inv(x.cpu())

        self.assertEqual(out_ref, out1.cpu())
        self.assertEqual(out1, out2)

    @onlyCUDA
    @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
    @setBlasBackendsToDefaultFinally
    def test_preferred_blas_library(self):
        # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions.
        m1 = torch.randint(2, 5, (2048, 2400), device='cuda', dtype=torch.float)
        m2 = torch.randint(2, 5, (128, 2400), device='cuda', dtype=torch.float)

        torch.backends.cuda.preferred_blas_library('cublaslt')
        out1 = torch.nn.functional.linear(m1, m2)

        torch.backends.cuda.preferred_blas_library('cublas')
        out2 = torch.nn.functional.linear(m1, m2)

        # Although the preferred blas flags don't affect CPU currently,
        # we set this to make sure the flag can switch back to default normally.
        out_ref = torch.nn.functional.linear(m1.cpu(), m2.cpu())

        self.assertEqual(out1, out2)
        self.assertEqual(out_ref, out2.cpu())
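
    # Editor's note: illustrative sketch (not an upstream test) of how the preference flags
    # toggled above can be queried and restored so a caller does not leak a global setting.
    # The query-without-argument behavior is an assumption based on the documented API.
    def _example_query_and_restore_linalg_backend(self):
        previous = torch.backends.cuda.preferred_linalg_library()  # query the current choice
        try:
            torch.backends.cuda.preferred_linalg_library('cusolver')
        finally:
            torch.backends.cuda.preferred_linalg_library(previous)  # restore the previous choice
        return previous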

    def test_permute_matmul(self):
        a = torch.ones([2, 5, 24, 24])
        b = torch.ones([3, 2, 5, 24, 24])
        c = a.permute(0, 1, 3, 2).matmul(b)
        self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720])

    def test_lower_precision_accumulation_with_ref_path(self):
        # fix https://github.com/pytorch/pytorch/issues/95125
        # and https://github.com/pytorch/pytorch/issues/83863
        # for bf16 accumulation in gemm ref path
        def check_correctness(fn, dtype, *args):
            expected = fn(*args).to(dtype=dtype)
            with torch.backends.mkldnn.flags(enabled=False):
                def test():
                    lower_args = (arg.to(dtype=dtype) for arg in args)
                    tmp_result = fn(*lower_args)
                    return tmp_result
                c = test()
                assert (torch.all(c == expected)), "Incorrect result with\n" \
                                                   f"expected: {expected}\n" \
                                                   f"got: {c}\n"

        # test matmul
        for dtype in [torch.bfloat16, torch.half]:
            for transa in [True, False]:
                for transb in [True, False]:
                    a = torch.ones(300, 300)
                    b = torch.ones(300, 300)
                    if transa:
                        a = a.transpose(0, 1).contiguous().transpose(0, 1)
                    if transb:
                        b = b.transpose(0, 1).contiguous().transpose(0, 1)
                    check_correctness(torch.matmul, dtype, a, b)
        # test bmm
        a = torch.ones(1, 1, 300)
        b = torch.ones(1, 300, 1)
        check_correctness(torch.bmm, torch.bfloat16, a, b)
        check_correctness(torch.bmm, torch.half, a, b)
        # test baddbmm
        a = torch.ones(1, 1, 300)
        b = torch.ones(1, 300, 1)
        c = torch.ones(1, 1, 1)
        check_correctness(torch.baddbmm, torch.bfloat16, c, a, b)
        check_correctness(torch.baddbmm, torch.half, c, a, b)
        # test mv/addmv
        for dtype in [torch.bfloat16, torch.half]:
            for trans in [True, False]:
                c = torch.ones(300) * -300
                a = torch.ones(300, 300)
                if trans:
                    a = a.transpose(0, 1).contiguous().transpose(0, 1)
                b = torch.ones(300)
                check_correctness(torch.mv, dtype, a, b)
                check_correctness(torch.addmv, dtype, c, a, b)
        # test dot
        a = torch.ones(300)
        b = torch.ones(300)
        check_correctness(torch.dot, torch.bfloat16, a, b)
        check_correctness(torch.dot, torch.half, a, b)

    @dtypes(torch.float, torch.half, torch.bfloat16)
    @parametrize("transpose_a", [True, False])
    @parametrize("transpose_b", [True, False])
    @parametrize("alpha", [0.0, 0.2, 1.0])
    @parametrize("beta", [0.0, 0.5, 1.0])
    def test_addmm_mv(self, device, dtype, transpose_a, transpose_b, alpha, beta):
        def gen_mat(w, h, use_transpose: bool = False):
            if not use_transpose:
                return torch.rand(w, h, dtype=dtype, device=device)
            return torch.rand(h, w, dtype=dtype, device=device).t()
        # Regression tests for https://github.com/pytorch/pytorch/issues/136299
        # Should only expose problems on aarch64, but let's be thorough
        m, n, k = 1, 8, 32
        A = gen_mat(m, k, transpose_a)
        B = gen_mat(k, n, transpose_b)
        C = torch.ones(m, n, dtype=dtype, device=device)
        rc = torch.addmm(C, A, B, alpha=alpha, beta=beta)
        ref = alpha * A @ B + beta * C
        self.assertEqual(rc, ref)

    @dtypes(torch.float, torch.double)
    @precisionOverride({torch.float32: 1e-4})
    def test_1_sized_with_0_strided(self, device, dtype):
        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
        res = torch.bmm(a_strided, b_strided)
        expect = torch.from_numpy(
            a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(device=device, dtype=dtype)
        self.assertEqual(expect, res)


instantiate_device_type_tests(TestLinalg, globals())

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()