# Owner(s): ["module: nn"] import math import random import string import unittest from functools import reduce from operator import mul import torch import torch.nn.functional as F import torch.nn.init as init from torch.testing._internal.common_utils import ( run_tests, skipIfNoLapack, skipIfTorchDynamo, slowTest, TEST_SCIPY, TestCase, ) if TEST_SCIPY: from scipy import stats class TestNNInit(TestCase): def setUp(self): super().setUp() random.seed(123) def _is_normal(self, tensor, mean, std): samples = tensor.view(-1).tolist() p_value = stats.kstest(samples, "norm", args=(mean, std))[1] return p_value > 0.0001 def _is_trunc_normal(self, tensor, mean, std, a, b): # scipy's trunc norm is suited for data drawn from N(0, 1), # so we need to transform our data to test it using scipy. z_samples = (tensor.view(-1) - mean) / std z_samples = z_samples.tolist() a0 = (a - mean) / std b0 = (b - mean) / std p_value = stats.kstest(z_samples, "truncnorm", args=(a0, b0))[1] return p_value > 0.0001 def _is_uniform(self, tensor, a, b): samples = tensor.view(-1).tolist() p_value = stats.kstest(samples, "uniform", args=(a, (b - a)))[1] return p_value > 0.0001 def _create_random_nd_tensor(self, dims, size_min, size_max): size = [random.randint(size_min, size_max) for _ in range(dims)] tensor = torch.zeros(size) return tensor def _random_float(self, a, b): return (b - a) * random.random() + a def test_calculate_gain_linear(self): for fn in [ "linear", "conv1d", "conv2d", "conv3d", "conv_transpose2d", "conv_transpose2d", "conv_transpose3d", ]: gain = init.calculate_gain(fn) self.assertEqual(gain, 1) def test_calculate_gain_nonlinear(self): for fn in ["sigmoid", "tanh", "relu", "leaky_relu"]: gain = init.calculate_gain(fn) if fn == "sigmoid": self.assertEqual(gain, 1) elif fn == "tanh": # 5 / 3 self.assertEqual(gain, 1.6666666666666667) elif fn == "relu": # sqrt(2) self.assertEqual(gain, 1.4142135623730951) elif fn == "leaky_relu": # sqrt(2 / 1 + slope^2)) self.assertEqual(gain, 1.4141428569978354) elif fn == "selu": self.assertEqual(gain, 0.75) def test_calculate_gain_leaky_relu(self): for param in [None, 0, 0.01, 10]: gain = init.calculate_gain("leaky_relu", param) if param is None: # Default slope is 0.01 self.assertEqual(gain, 1.4141428569978354) elif param == 0: # No slope = same gain as normal ReLU self.assertEqual(gain, 1.4142135623730951) elif param == 0.01: self.assertEqual(gain, 1.4141428569978354) elif param == 10: self.assertEqual(gain, 0.14071950894605836) def test_calculate_gain_leaky_relu_only_accepts_numbers(self): for param in [True, [1], {"a": "b"}]: with self.assertRaises(ValueError): init.calculate_gain("leaky_relu", param) def test_calculate_gain_only_accepts_valid_nonlinearities(self): for n in [2, 5, 25]: # Generate random strings of lengths that definitely aren't supported random_string = "".join( [random.choice(string.ascii_lowercase) for i in range(n)] ) with self.assertRaises(ValueError): init.calculate_gain(random_string) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @skipIfTorchDynamo("scipy.kstest is failing under dynamo") def test_uniform(self): for dims in [1, 2, 4]: input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) a = self._random_float(-3, 3) b = a + self._random_float(1, 5) init.uniform_(input_tensor, a=a, b=b) assert self._is_uniform(input_tensor, a, b) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @skipIfTorchDynamo("scipy.kstest is failing under dynamo") def test_normal(self): for dims in [1, 2, 4]: input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) mean = self._random_float(-3, 3) std = self._random_float(1, 5) init.normal_(input_tensor, mean=mean, std=std) assert self._is_normal(input_tensor, mean, std) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @skipIfTorchDynamo("scipy.kstest is failing under dynamo") def test_trunc_normal(self): for dims in [1, 2, 4]: input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) mean = self._random_float(-3, 3) std = self._random_float(0.01, 1) a = self._random_float(mean - 2 * std, mean) b = self._random_float(mean, mean + 2 * std) init.trunc_normal_(input_tensor, mean=mean, std=std, a=a, b=b) assert self._is_trunc_normal(input_tensor, mean, std, a, b) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @skipIfTorchDynamo("scipy.kstest is failing under dynamo") def test_trunc_normal_generator(self): gen = torch.Generator() gen.manual_seed(42) input_tensor = torch.empty(5) init.trunc_normal_(input_tensor, generator=gen) ref = torch.empty(5) torch.manual_seed(42) init.trunc_normal_(ref) self.assertEqual(input_tensor, ref) assert self._is_trunc_normal(input_tensor, mean=0, std=1, a=0, b=1) def test_constant(self): for dims in [1, 2, 4]: input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5) val = self._random_float(1, 10) init.constant_(input_tensor, val) self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) def test_ones_and_zeros(self): for init_fn_, val in zip([init.ones_, init.zeros_], [1, 0]): for dims in [1, 2, 4]: input_tensor = self._create_random_nd_tensor( dims, size_min=1, size_max=5 ) init_fn_(input_tensor) self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) def test_eye(self): input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5) init.eye_(input_tensor) # Check every single element for i in range(input_tensor.size(0)): for j in range(input_tensor.size(1)): if i == j: assert input_tensor[i][j] == 1 else: assert input_tensor[i][j] == 0 def test_eye_only_works_on_2d_inputs(self): for dims in [1, 3]: with self.assertRaises(ValueError): tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) init.eye_(tensor) def test_dirac_properties(self): for dims in [3, 4, 5]: for groups in [1, 2, 3]: # prepare random tensor with random sizes, but fits groups a, c, d, e = (random.randint(1, 5) for _ in range(4)) b = random.randint( 1, 5 * groups ) # same range as a*groups but all range allowed # make sure first dim divides by groups input_tensor = torch.randn((a * groups, b, c, d, e)[:dims]) init.dirac_(input_tensor, groups) c_out, c_in = input_tensor.size(0) // groups, input_tensor.size(1) min_d = min(c_out, c_in) # Check number of nonzeros is equivalent to smallest dim (for each group) assert torch.nonzero(input_tensor).size(0) == min_d * groups # Check sum of values (can have precision issues, hence assertEqual) is also equivalent self.assertEqual(input_tensor.sum(), min_d * groups) def test_dirac_identity(self): for groups in [1, 3]: batch, in_c, out_c, size, kernel_size = ( 8, 3, 9, 5, 3, ) # in_c, out_c must divide by groups eff_out_c = out_c // groups # Test 1D input_var = torch.randn(batch, in_c, size) filter_var = torch.zeros(eff_out_c, in_c, kernel_size) filter_var = torch.cat([filter_var] * groups) init.dirac_(filter_var, groups) output_var = F.conv1d(input_var, filter_var) input_tensor, output_tensor = ( input_var.data, output_var.data, ) # Variables do not support nonzero for g in range(groups): # Assert in_c outputs are preserved (per each group) self.assertEqual( input_tensor[:, :, 1:-1], output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :], ) # Assert extra outputs are 0 assert ( torch.nonzero( output_tensor[:, eff_out_c * g + in_c : eff_out_c * (g + 1), :] ).numel() == 0 ) # Test 2D input_var = torch.randn(batch, in_c, size, size) filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size) filter_var = torch.cat([filter_var] * groups) init.dirac_(filter_var, groups) output_var = F.conv2d(input_var, filter_var) input_tensor, output_tensor = ( input_var.data, output_var.data, ) # Variables do not support nonzero for g in range(groups): # Assert in_c outputs are preserved (per each group) self.assertEqual( input_tensor[:, :, 1:-1, 1:-1], output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :], ) # Assert extra outputs are 0 assert ( torch.nonzero( output_tensor[ :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, : ] ).numel() == 0 ) # Test 3D input_var = torch.randn(batch, in_c, size, size, size) filter_var = torch.zeros( eff_out_c, in_c, kernel_size, kernel_size, kernel_size ) filter_var = torch.cat([filter_var] * groups) init.dirac_(filter_var, groups) output_var = F.conv3d(input_var, filter_var) input_tensor, output_tensor = input_var.data, output_var.data for g in range(groups): # Assert in_c outputs are preserved (per each group) self.assertEqual( input_tensor[:, :, 1:-1, 1:-1, 1:-1], output_tensor[:, eff_out_c * g : eff_out_c * g + in_c, :, :, :], ) # Assert extra outputs are 0 assert ( torch.nonzero( output_tensor[ :, eff_out_c * g + in_c : eff_out_c * (g + 1), :, :, : ] ).numel() == 0 ) def test_dirac_only_works_on_3_4_5d_inputs(self): for dims in [1, 2, 6]: with self.assertRaises(ValueError): tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) init.dirac_(tensor) def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self): for dims in [0, 1]: tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) with self.assertRaises(ValueError): init.xavier_uniform_(tensor) def test_xavier_normal_errors_on_inputs_smaller_than_2d(self): for dims in [0, 1]: tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) with self.assertRaises(ValueError): init.xavier_normal_(tensor) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @slowTest def test_xavier_uniform(self): for use_gain in [True, False]: for dims in [2, 4]: input_tensor = self._create_random_nd_tensor( dims, size_min=20, size_max=25 ) gain = 1 if use_gain: gain = self._random_float(0.1, 2) init.xavier_uniform_(input_tensor, gain=gain) else: init.xavier_uniform_(input_tensor) fan_in = input_tensor.size(1) fan_out = input_tensor.size(0) if input_tensor.dim() > 2: fan_in *= input_tensor[0, 0].numel() fan_out *= input_tensor[0, 0].numel() expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out)) bounds = expected_std * math.sqrt(3) assert self._is_uniform(input_tensor, -bounds, bounds) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @skipIfTorchDynamo("scipy.kstest is failing under dynamo") def test_xavier_normal(self): for use_gain in [True, False]: for dims in [2, 4]: input_tensor = self._create_random_nd_tensor( dims, size_min=20, size_max=25 ) gain = 1 if use_gain: gain = self._random_float(0.1, 2) init.xavier_normal_(input_tensor, gain=gain) else: init.xavier_normal_(input_tensor) fan_in = input_tensor.size(1) fan_out = input_tensor.size(0) if input_tensor.dim() > 2: fan_in *= input_tensor[0, 0].numel() fan_out *= input_tensor[0, 0].numel() expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out)) assert self._is_normal(input_tensor, 0, expected_std) def test_kaiming_uniform_errors_on_inputs_smaller_than_2d(self): for dims in [0, 1]: with self.assertRaises(ValueError): tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) init.kaiming_uniform_(tensor) def test_kaiming_normal_errors_on_inputs_smaller_than_2d(self): for dims in [0, 1]: with self.assertRaises(ValueError): tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) init.kaiming_normal_(tensor) def test_kaiming_uniform_warning_on_0element_tensor(self): tensor = torch.empty(0, 1) with self.assertWarnsRegex( UserWarning, "Initializing zero-element tensors is a no-op" ): _ = init.kaiming_uniform_(tensor) def test_kaiming_normal_warning_on_0element_tensor(self): tensor = torch.empty(0, 1) with self.assertWarnsRegex( UserWarning, "Initializing zero-element tensors is a no-op" ): _ = init.kaiming_normal_(tensor) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @skipIfTorchDynamo("scipy.kstest is failing under dynamo") def test_kaiming_uniform(self): for use_a in [True, False]: for dims in [2, 4]: for mode in ["fan_in", "fan_out"]: input_tensor = self._create_random_nd_tensor( dims, size_min=20, size_max=25 ) if use_a: a = self._random_float(0.1, 2) init.kaiming_uniform_(input_tensor, a=a, mode=mode) else: a = 0 init.kaiming_uniform_(input_tensor, mode=mode) fan_in = input_tensor.size(1) fan_out = input_tensor.size(0) if input_tensor.dim() > 2: fan_in *= input_tensor[0, 0].numel() fan_out *= input_tensor[0, 0].numel() if mode == "fan_in": n = fan_in else: n = fan_out expected_std = math.sqrt(2.0 / ((1 + a**2) * n)) bounds = expected_std * math.sqrt(3.0) assert self._is_uniform(input_tensor, -bounds, bounds) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @skipIfTorchDynamo("scipy.kstest is failing under dynamo") def test_kaiming_normal(self): for use_a in [True, False]: for dims in [2, 4]: for mode in ["fan_in", "fan_out"]: input_tensor = self._create_random_nd_tensor( dims, size_min=20, size_max=25 ) if use_a: a = self._random_float(0.1, 2) init.kaiming_normal_(input_tensor, a=a, mode=mode) else: a = 0 init.kaiming_normal_(input_tensor, mode=mode) fan_in = input_tensor.size(1) fan_out = input_tensor.size(0) if input_tensor.dim() > 2: fan_in *= input_tensor[0, 0].numel() fan_out *= input_tensor[0, 0].numel() if mode == "fan_in": n = fan_in else: n = fan_out expected_std = math.sqrt(2.0 / ((1 + a**2) * n)) assert self._is_normal(input_tensor, 0, expected_std) def test_sparse_only_works_on_2d_inputs(self): for dims in [1, 3]: with self.assertRaises(ValueError): sparsity = self._random_float(0.1, 0.9) tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) init.sparse_(tensor, sparsity) @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") @skipIfTorchDynamo("scipy.kstest is failing under dynamo") def test_sparse_default_std(self): for use_random_std in [True, False]: input_tensor = self._create_random_nd_tensor(2, size_min=30, size_max=35) rows, cols = input_tensor.size(0), input_tensor.size(1) sparsity = self._random_float(0.1, 0.2) std = 0.01 # default std if use_random_std: std = self._random_float(0.01, 0.2) init.sparse_(input_tensor, sparsity=sparsity, std=std) else: init.sparse_(input_tensor, sparsity=sparsity) for col_idx in range(input_tensor.size(1)): column = input_tensor[:, col_idx] assert column[column == 0].nelement() >= math.ceil(sparsity * rows) assert self._is_normal(input_tensor[input_tensor != 0], 0, std) @skipIfNoLapack def test_orthogonal(self): for use_gain in [True, False]: for tensor_size in [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]: input_tensor = torch.zeros(tensor_size) gain = 1.0 if use_gain: gain = self._random_float(0.1, 2) init.orthogonal_(input_tensor, gain=gain) else: init.orthogonal_(input_tensor) rows, cols = tensor_size[0], reduce(mul, tensor_size[1:]) flattened_tensor = input_tensor.view(rows, cols) if rows > cols: self.assertEqual( torch.mm(flattened_tensor.t(), flattened_tensor), torch.eye(cols) * gain**2, atol=1e-6, rtol=0, ) else: self.assertEqual( torch.mm(flattened_tensor, flattened_tensor.t()), torch.eye(rows) * gain**2, atol=1e-6, rtol=0, ) def test_deprecation(self): x = torch.randn(3, 3) def fn(): init.normal(x) with self.assertWarnsRegex( FutureWarning, "deprecated", msg="methods not suffixed with underscore should be deprecated", ): fn() if __name__ == "__main__": run_tests()