1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: nn"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport contextlib 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport re 6*da0073e9SAndroid Build Coastguard Workerimport subprocess 7*da0073e9SAndroid Build Coastguard Workerimport sys 8*da0073e9SAndroid Build Coastguard Workerimport unittest 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerimport torch 11*da0073e9SAndroid Build Coastguard Workerimport torch.nn.utils.stateless as stateless 12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_MULTIGPU 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase, parametrize, instantiate_parametrized_tests, \ 14*da0073e9SAndroid Build Coastguard Worker subtest 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Workerclass MockModule(torch.nn.Module): 18*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 19*da0073e9SAndroid Build Coastguard Worker super().__init__() 20*da0073e9SAndroid Build Coastguard Worker self.l1 = torch.nn.Linear(1, 1) 21*da0073e9SAndroid Build Coastguard Worker self.buffer = torch.nn.Buffer(torch.ones(1)) 22*da0073e9SAndroid Build Coastguard Worker self.foo = 0.0 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 25*da0073e9SAndroid Build Coastguard Worker return self.l1(x) + self.buffer 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Workerclass MockTiedModule(torch.nn.Module): 29*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 30*da0073e9SAndroid Build Coastguard Worker super().__init__() 31*da0073e9SAndroid Build 
Coastguard Worker self.l1 = torch.nn.Linear(1, 1) 32*da0073e9SAndroid Build Coastguard Worker self.tied_bias = self.l1.bias 33*da0073e9SAndroid Build Coastguard Worker self.buffer = torch.nn.Buffer(torch.ones(1)) 34*da0073e9SAndroid Build Coastguard Worker self.tied_buffer = self.buffer 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 37*da0073e9SAndroid Build Coastguard Worker return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Workerclass TestStatelessFunctionalAPI(TestCase): 41*da0073e9SAndroid Build Coastguard Worker def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''): 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 1)).to(device) 44*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[1.0]], device=device) 45*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([0.0], device=device) 46*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([0.0], device=device) 47*da0073e9SAndroid Build Coastguard Worker if prefix != '': 48*da0073e9SAndroid Build Coastguard Worker parameters = {f'{prefix}.l1.weight': weight, 49*da0073e9SAndroid Build Coastguard Worker f'{prefix}.l1.bias': bias, 50*da0073e9SAndroid Build Coastguard Worker f'{prefix}.buffer': buffer} 51*da0073e9SAndroid Build Coastguard Worker else: 52*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 53*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 54*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 55*da0073e9SAndroid Build Coastguard Worker to_check = module 56*da0073e9SAndroid Build Coastguard Worker if prefix != '': 57*da0073e9SAndroid Build Coastguard Worker to_check = getattr(module, prefix) 58*da0073e9SAndroid Build Coastguard Worker 
prev_weight = to_check.l1.weight.clone() 59*da0073e9SAndroid Build Coastguard Worker prev_buffer = to_check.buffer.clone() 60*da0073e9SAndroid Build Coastguard Worker # the parameters represent an identity function contrary to the 61*da0073e9SAndroid Build Coastguard Worker # existing params in module. So here we expect the result to be the 62*da0073e9SAndroid Build Coastguard Worker # same as the input if the weight swapping went well. 63*da0073e9SAndroid Build Coastguard Worker res = functional_call(module, parameters, x) 64*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, res) 65*da0073e9SAndroid Build Coastguard Worker # check that the weight remain unmodified 66*da0073e9SAndroid Build Coastguard Worker cur_weight = to_check.l1.weight 67*da0073e9SAndroid Build Coastguard Worker cur_buffer = to_check.buffer 68*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cur_weight, prev_weight) 69*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cur_buffer, prev_buffer) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 72*da0073e9SAndroid Build Coastguard Worker def _ensure_module_unchanged(self, module, message): 73*da0073e9SAndroid Build Coastguard Worker orig_parameters, orig_buffers = tuple(module.parameters()), tuple(module.buffers()) 74*da0073e9SAndroid Build Coastguard Worker orig_tensors = orig_parameters + orig_buffers 75*da0073e9SAndroid Build Coastguard Worker orig_tensors_values = tuple(t.clone() for t in orig_tensors) 76*da0073e9SAndroid Build Coastguard Worker try: 77*da0073e9SAndroid Build Coastguard Worker yield module 78*da0073e9SAndroid Build Coastguard Worker finally: 79*da0073e9SAndroid Build Coastguard Worker parameters, buffers = tuple(module.parameters()), tuple(module.buffers()) 80*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 81*da0073e9SAndroid Build Coastguard Worker len(parameters) == len(orig_parameters) 82*da0073e9SAndroid Build Coastguard 
Worker and len(buffers) == len(orig_buffers) 83*da0073e9SAndroid Build Coastguard Worker and all( 84*da0073e9SAndroid Build Coastguard Worker t1 is t2 and torch.allclose(t1, t3) 85*da0073e9SAndroid Build Coastguard Worker for t1, t2, t3 in zip( 86*da0073e9SAndroid Build Coastguard Worker orig_tensors, 87*da0073e9SAndroid Build Coastguard Worker parameters + buffers, 88*da0073e9SAndroid Build Coastguard Worker orig_tensors_values, 89*da0073e9SAndroid Build Coastguard Worker ) 90*da0073e9SAndroid Build Coastguard Worker ), 91*da0073e9SAndroid Build Coastguard Worker message, 92*da0073e9SAndroid Build Coastguard Worker ) 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 95*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 96*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 97*da0073e9SAndroid Build Coastguard Worker ]) 98*da0073e9SAndroid Build Coastguard Worker def test_functional_call(self, functional_call): 99*da0073e9SAndroid Build Coastguard Worker module = MockModule() 100*da0073e9SAndroid Build Coastguard Worker self._run_call_with_mock_module(module, functional_call) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 103*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 104*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 105*da0073e9SAndroid Build Coastguard Worker ]) 106*da0073e9SAndroid Build Coastguard Worker def test_functional_call_with_jit(self, functional_call): 107*da0073e9SAndroid Build Coastguard Worker module = MockModule() 108*da0073e9SAndroid Build Coastguard Worker jit_module = torch.jit.script(module) 109*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 110*da0073e9SAndroid Build Coastguard Worker RuntimeError, 
111*da0073e9SAndroid Build Coastguard Worker r'used with Jitted modules' 112*da0073e9SAndroid Build Coastguard Worker ): 113*da0073e9SAndroid Build Coastguard Worker self._run_call_with_mock_module(jit_module, functional_call) 114*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 1)) 115*da0073e9SAndroid Build Coastguard Worker traced_module = torch.jit.trace(module, x) 116*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 117*da0073e9SAndroid Build Coastguard Worker RuntimeError, 118*da0073e9SAndroid Build Coastguard Worker r'used with Jitted modules' 119*da0073e9SAndroid Build Coastguard Worker ): 120*da0073e9SAndroid Build Coastguard Worker self._run_call_with_mock_module(traced_module, functional_call) 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported') 123*da0073e9SAndroid Build Coastguard Worker @unittest.skip("This doesn't work right now") 124*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 125*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 126*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 127*da0073e9SAndroid Build Coastguard Worker ]) 128*da0073e9SAndroid Build Coastguard Worker def test_functional_call_with_data_parallel(self, functional_call): 129*da0073e9SAndroid Build Coastguard Worker module = MockModule() 130*da0073e9SAndroid Build Coastguard Worker module.cuda() 131*da0073e9SAndroid Build Coastguard Worker dp_module = torch.nn.DataParallel(module, [0, 1]) 132*da0073e9SAndroid Build Coastguard Worker self._run_call_with_mock_module(dp_module, functional_call, device='cuda', prefix='module') 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported') 135*da0073e9SAndroid Build Coastguard Worker 
@parametrize("functional_call", [ 136*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 137*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 138*da0073e9SAndroid Build Coastguard Worker ]) 139*da0073e9SAndroid Build Coastguard Worker def test_functional_call_with_data_parallel_error(self, functional_call): 140*da0073e9SAndroid Build Coastguard Worker module = MockModule() 141*da0073e9SAndroid Build Coastguard Worker module.cuda() 142*da0073e9SAndroid Build Coastguard Worker dp_module = torch.nn.DataParallel(module, [0, 1]) 143*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'used with nn.DataParallel module'): 144*da0073e9SAndroid Build Coastguard Worker functional_call( 145*da0073e9SAndroid Build Coastguard Worker dp_module, 146*da0073e9SAndroid Build Coastguard Worker {'module.weight': torch.zeros(5, device='cuda')}, 147*da0073e9SAndroid Build Coastguard Worker (torch.ones(2, 5, device='cuda'),)) 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 150*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 151*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 152*da0073e9SAndroid Build Coastguard Worker ]) 153*da0073e9SAndroid Build Coastguard Worker def test_functional_call_with_gradient(self, functional_call): 154*da0073e9SAndroid Build Coastguard Worker module = MockModule() 155*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 1)) 156*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[1.0]], requires_grad=True) 157*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([0.0], requires_grad=True) 158*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([0.0]) 159*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 160*da0073e9SAndroid 
Build Coastguard Worker 'l1.bias': bias, 161*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 162*da0073e9SAndroid Build Coastguard Worker res = functional_call(module, parameters, x) 163*da0073e9SAndroid Build Coastguard Worker # Check that a backward step calculates the gradient of the supplied parameters 164*da0073e9SAndroid Build Coastguard Worker res.backward() 165*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(weight.grad) 166*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(bias.grad) 167*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(buffer.grad) 168*da0073e9SAndroid Build Coastguard Worker # Gradient was not calculated for the module stated and buffers 169*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(module.l1.weight.grad) 170*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(module.l1.bias.grad) 171*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(module.buffer.grad) 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 174*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 175*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 176*da0073e9SAndroid Build Coastguard Worker ]) 177*da0073e9SAndroid Build Coastguard Worker def test_functional_batch_norm(self, functional_call): 178*da0073e9SAndroid Build Coastguard Worker module = torch.nn.BatchNorm1d(10) 179*da0073e9SAndroid Build Coastguard Worker module.train() # Allow stats update 180*da0073e9SAndroid Build Coastguard Worker # lets replace the running_mean buffer and check if its correctly updated 181*da0073e9SAndroid Build Coastguard Worker x = torch.full((20, 10), 128.0) 182*da0073e9SAndroid Build Coastguard Worker rm = torch.zeros(10) 183*da0073e9SAndroid Build Coastguard Worker parameters = {'running_mean': rm} 184*da0073e9SAndroid Build Coastguard Worker prev_rm = 
module.running_mean.clone() 185*da0073e9SAndroid Build Coastguard Worker res = functional_call(module, parameters, x) 186*da0073e9SAndroid Build Coastguard Worker cur_rm = module.running_mean 187*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cur_rm, prev_rm) 188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rm, torch.full((10,), 12.8)) 189*da0073e9SAndroid Build Coastguard Worker # Now run functional without reparametrization and check that the module has 190*da0073e9SAndroid Build Coastguard Worker # been updated 191*da0073e9SAndroid Build Coastguard Worker res = functional_call(module, {}, x) 192*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module.running_mean, torch.full((10,), 12.8)) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 195*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 196*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 197*da0073e9SAndroid Build Coastguard Worker ]) 198*da0073e9SAndroid Build Coastguard Worker def test_circular_references(self, functional_call): 199*da0073e9SAndroid Build Coastguard Worker module = MockModule() 200*da0073e9SAndroid Build Coastguard Worker # Add a circular reference 201*da0073e9SAndroid Build Coastguard Worker module.l1.m = module 202*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 1)) 203*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[1.0]]) 204*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([0.0]) 205*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([0.0]) 206*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.m.l1.weight': weight, 207*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 208*da0073e9SAndroid Build Coastguard Worker 'l1.m.buffer': buffer} 209*da0073e9SAndroid Build Coastguard Worker prev_weight = module.l1.weight.clone() 
210*da0073e9SAndroid Build Coastguard Worker prev_buffer = module.buffer.clone() 211*da0073e9SAndroid Build Coastguard Worker res = functional_call(module, parameters, x, tie_weights=False) 212*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, res) 213*da0073e9SAndroid Build Coastguard Worker # check that the weights remain unmodified and were correctly accesed 214*da0073e9SAndroid Build Coastguard Worker cur_weight = module.l1.weight 215*da0073e9SAndroid Build Coastguard Worker cur_buffer = module.buffer 216*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cur_weight, prev_weight) 217*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cur_buffer, prev_buffer) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 220*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 221*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 222*da0073e9SAndroid Build Coastguard Worker ]) 223*da0073e9SAndroid Build Coastguard Worker def test_reparametrized_module_change_parametrization_original(self, functional_call): 224*da0073e9SAndroid Build Coastguard Worker module = MockModule() 225*da0073e9SAndroid Build Coastguard Worker torch.nn.utils.parametrizations.spectral_norm(module.l1) 226*da0073e9SAndroid Build Coastguard Worker self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) 227*da0073e9SAndroid Build Coastguard Worker orig_sn_weight = module.l1.weight.clone() 228*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 1)) 229*da0073e9SAndroid Build Coastguard Worker # We substitute the parameter inside the parametrization 230*da0073e9SAndroid Build Coastguard Worker # the parametrization itself is not overwritten so it will be applied with a different 231*da0073e9SAndroid Build Coastguard Worker # value for the original tensor 232*da0073e9SAndroid Build Coastguard 
Worker parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])), 233*da0073e9SAndroid Build Coastguard Worker 'l1.bias': torch.tensor([0.0]), 234*da0073e9SAndroid Build Coastguard Worker 'buffer': torch.tensor([0.0])} 235*da0073e9SAndroid Build Coastguard Worker res = functional_call(module, parameters, x) 236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, res) 237*da0073e9SAndroid Build Coastguard Worker # verify that the spectral normalization is still applied 238*da0073e9SAndroid Build Coastguard Worker self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) 239*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_sn_weight, module.l1.weight) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 242*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 243*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 244*da0073e9SAndroid Build Coastguard Worker ]) 245*da0073e9SAndroid Build Coastguard Worker def test_reparametrize_module_fail_reset_to_original(self, functional_call): 246*da0073e9SAndroid Build Coastguard Worker module = MockModule() 247*da0073e9SAndroid Build Coastguard Worker torch.nn.utils.parametrizations.spectral_norm(module.l1) 248*da0073e9SAndroid Build Coastguard Worker self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) 249*da0073e9SAndroid Build Coastguard Worker orig_sn_weight = module.l1.weight.clone() 250*da0073e9SAndroid Build Coastguard Worker # We substitute the parameter inside the parametrization 251*da0073e9SAndroid Build Coastguard Worker # the parametrization itself is not overwritten so it will be applied with a different 252*da0073e9SAndroid Build Coastguard Worker # value for the original tensor 253*da0073e9SAndroid Build Coastguard Worker parameters = 
{'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])), 254*da0073e9SAndroid Build Coastguard Worker 'l1.bias': torch.tensor([0.0]), 255*da0073e9SAndroid Build Coastguard Worker 'buffer': torch.tensor([0.0])} 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "shapes cannot be multiplied"): 258*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.disable 259*da0073e9SAndroid Build Coastguard Worker def _error_case(): 260*da0073e9SAndroid Build Coastguard Worker x = torch.rand((4, 5)) # to work, it should be of size (1, 1) 261*da0073e9SAndroid Build Coastguard Worker functional_call(module, parameters, x) # this call will fail because x is the wrong size 262*da0073e9SAndroid Build Coastguard Worker _error_case() 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker # verify that the spectral normalization is still applied 265*da0073e9SAndroid Build Coastguard Worker self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) 266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_sn_weight, module.l1.weight) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 269*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 270*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 271*da0073e9SAndroid Build Coastguard Worker ]) 272*da0073e9SAndroid Build Coastguard Worker def test_reparametrize_some_weights(self, functional_call): 273*da0073e9SAndroid Build Coastguard Worker module = MockModule() 274*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[2.0]]) 275*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([5.0]) 276*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([3.0]) 277*da0073e9SAndroid Build 
Coastguard Worker extra = torch.tensor([1.0]) 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight} 280*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 281*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x) 282*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + module.l1.bias + module.buffer) 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 285*da0073e9SAndroid Build Coastguard Worker 'extra': extra} 286*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 287*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x) 288*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + module.l1.bias + module.buffer) 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 291*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 292*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 293*da0073e9SAndroid Build Coastguard Worker ]) 294*da0073e9SAndroid Build Coastguard Worker def test_reparametrize_strict(self, functional_call): 295*da0073e9SAndroid Build Coastguard Worker module = MockModule() 296*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[2.0]]) 297*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([5.0]) 298*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([3.0]) 299*da0073e9SAndroid Build Coastguard Worker extra = torch.tensor([1.0]) 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker # All weights no error 302*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 303*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 304*da0073e9SAndroid Build 
Coastguard Worker 'buffer': buffer} 305*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 306*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 307*da0073e9SAndroid Build Coastguard Worker module, 308*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a successful call', 309*da0073e9SAndroid Build Coastguard Worker ): 310*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x, strict=True) 311*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + bias + buffer) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker # Some weights 314*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight} 315*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 316*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 317*da0073e9SAndroid Build Coastguard Worker module, 318*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 319*da0073e9SAndroid Build Coastguard Worker ): 320*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 321*da0073e9SAndroid Build Coastguard Worker RuntimeError, 322*da0073e9SAndroid Build Coastguard Worker re.escape("Missing key(s): 'buffer', 'l1.bias'."), 323*da0073e9SAndroid Build Coastguard Worker ): 324*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x, strict=True) 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker # Extra keys 327*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 328*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 329*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer, 330*da0073e9SAndroid Build Coastguard Worker 'extra': extra} 331*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 332*da0073e9SAndroid Build Coastguard Worker with 
self._ensure_module_unchanged( 333*da0073e9SAndroid Build Coastguard Worker module, 334*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 335*da0073e9SAndroid Build Coastguard Worker ): 336*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 337*da0073e9SAndroid Build Coastguard Worker RuntimeError, 338*da0073e9SAndroid Build Coastguard Worker re.escape("Unexpected key(s): 'extra'."), 339*da0073e9SAndroid Build Coastguard Worker ): 340*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x, strict=True) 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker # Some weights with extra keys 343*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 344*da0073e9SAndroid Build Coastguard Worker 'extra': extra} 345*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 346*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 347*da0073e9SAndroid Build Coastguard Worker module, 348*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 349*da0073e9SAndroid Build Coastguard Worker ): 350*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 351*da0073e9SAndroid Build Coastguard Worker RuntimeError, 352*da0073e9SAndroid Build Coastguard Worker re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'buffer', 'l1.bias'."), 353*da0073e9SAndroid Build Coastguard Worker ): 354*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x, strict=True) 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 357*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 358*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 359*da0073e9SAndroid Build 
Coastguard Worker ]) 360*da0073e9SAndroid Build Coastguard Worker def test_reparametrize_special(self, functional_call): 361*da0073e9SAndroid Build Coastguard Worker class NonTensor: 362*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 363*da0073e9SAndroid Build Coastguard Worker return f'<{self.__class__.__name__}>' 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker module = MockModule() 366*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[2.0]]) 367*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([5.0]) 368*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([3.0]) 369*da0073e9SAndroid Build Coastguard Worker non_tensor = NonTensor() 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker # Set to None 372*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 373*da0073e9SAndroid Build Coastguard Worker 'l1.bias': None, 374*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 375*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 376*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 377*da0073e9SAndroid Build Coastguard Worker module, 378*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a successful call', 379*da0073e9SAndroid Build Coastguard Worker ): 380*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x) 381*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + buffer) 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker # Set non-tensor 384*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': non_tensor} 385*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 386*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 387*da0073e9SAndroid Build Coastguard Worker module, 388*da0073e9SAndroid Build 
Coastguard Worker 'the module should not have been modified by a failed call', 389*da0073e9SAndroid Build Coastguard Worker ): 390*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 391*da0073e9SAndroid Build Coastguard Worker TypeError, 392*da0073e9SAndroid Build Coastguard Worker re.escape("<NonTensor> is not an instance of torch.Tensor"), 393*da0073e9SAndroid Build Coastguard Worker ): 394*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x) 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker # Set non-tensor attribute 397*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 'foo': torch.tensor([1.0])} 398*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 399*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 400*da0073e9SAndroid Build Coastguard Worker module, 401*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 402*da0073e9SAndroid Build Coastguard Worker ): 403*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 404*da0073e9SAndroid Build Coastguard Worker TypeError, 405*da0073e9SAndroid Build Coastguard Worker re.escape("attribute `foo`: 0.0 is not an instance of torch.Tensor"), 406*da0073e9SAndroid Build Coastguard Worker ): 407*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x) 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker # Set non-exist submodule 410*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 411*da0073e9SAndroid Build Coastguard Worker 'l2.bias': bias} 412*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 413*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 414*da0073e9SAndroid Build Coastguard Worker module, 415*da0073e9SAndroid Build Coastguard Worker 'the module should not have 
been modified by a failed call', 416*da0073e9SAndroid Build Coastguard Worker ): 417*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 418*da0073e9SAndroid Build Coastguard Worker AttributeError, 419*da0073e9SAndroid Build Coastguard Worker re.escape("MockModule has no attribute `l2`"), 420*da0073e9SAndroid Build Coastguard Worker ): 421*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 424*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 425*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 426*da0073e9SAndroid Build Coastguard Worker ]) 427*da0073e9SAndroid Build Coastguard Worker def test_tied_weights_warns(self, functional_call): 428*da0073e9SAndroid Build Coastguard Worker module = MockModule() 429*da0073e9SAndroid Build Coastguard Worker module.tied_bias = module.l1.bias 430*da0073e9SAndroid Build Coastguard Worker module.tied_buffer = torch.nn.Buffer(module.buffer) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 433*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 434*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 435*da0073e9SAndroid Build Coastguard Worker ]) 436*da0073e9SAndroid Build Coastguard Worker def test_reparametrize_tie_weights(self, functional_call): 437*da0073e9SAndroid Build Coastguard Worker module = MockTiedModule() 438*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[2.0]]) 439*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([5.0]) 440*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([3.0]) 441*da0073e9SAndroid Build Coastguard Worker extra = torch.tensor([1.0]) 442*da0073e9SAndroid 
Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 444*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 445*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 446*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 447*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x, tie_weights=True) 448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + bias + bias + buffer + buffer) 449*da0073e9SAndroid Build Coastguard Worker 450*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 451*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 452*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer, 453*da0073e9SAndroid Build Coastguard Worker 'extra': extra} 454*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 455*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x, tie_weights=True) 456*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + bias + bias + buffer + buffer) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 459*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 460*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 461*da0073e9SAndroid Build Coastguard Worker ]) 462*da0073e9SAndroid Build Coastguard Worker def test_reparametrize_tie_some_weights(self, functional_call): 463*da0073e9SAndroid Build Coastguard Worker module = MockTiedModule() 464*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[2.0]]) 465*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([3.0]) 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 468*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 
469*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 470*da0073e9SAndroid Build Coastguard Worker out = stateless.functional_call(module, parameters, x, tie_weights=True) 471*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer) 472*da0073e9SAndroid Build Coastguard Worker 473*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 474*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 475*da0073e9SAndroid Build Coastguard Worker subtest(stateless._functional_call, "stateless") 476*da0073e9SAndroid Build Coastguard Worker ]) 477*da0073e9SAndroid Build Coastguard Worker def test_tied_weights_errors(self, functional_call): 478*da0073e9SAndroid Build Coastguard Worker module = MockTiedModule() 479*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[1.0]]) 480*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([0.0]) 481*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([0.0]) 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 484*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 485*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 486*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 487*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True)) 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker # if tied values are the same tensors, shouldn't warn 490*da0073e9SAndroid Build Coastguard Worker parameters['tied_bias'] = bias 491*da0073e9SAndroid Build Coastguard Worker parameters['tied_buffer'] = buffer 492*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True)) 493*da0073e9SAndroid Build Coastguard Worker del 
parameters['tied_bias'] 494*da0073e9SAndroid Build Coastguard Worker del parameters['tied_buffer'] 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 497*da0073e9SAndroid Build Coastguard Worker ValueError, 498*da0073e9SAndroid Build Coastguard Worker re.escape("functional_call got multiple values for keys ['l1.bias', 'tied_bias']"), 499*da0073e9SAndroid Build Coastguard Worker ): 500*da0073e9SAndroid Build Coastguard Worker parameters['tied_bias'] = torch.tensor([5.0]) 501*da0073e9SAndroid Build Coastguard Worker functional_call(module, parameters, x, tie_weights=True) 502*da0073e9SAndroid Build Coastguard Worker del parameters['tied_bias'] 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 505*da0073e9SAndroid Build Coastguard Worker ValueError, 506*da0073e9SAndroid Build Coastguard Worker re.escape("functional_call got multiple values for keys ['buffer', 'tied_buffer']"), 507*da0073e9SAndroid Build Coastguard Worker ): 508*da0073e9SAndroid Build Coastguard Worker parameters['tied_buffer'] = torch.tensor([5.0]) 509*da0073e9SAndroid Build Coastguard Worker functional_call(module, parameters, x, tie_weights=True) 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker def test_tied_weights_no_error_without_flag(self): 512*da0073e9SAndroid Build Coastguard Worker module = MockTiedModule() 513*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[1.0]]) 514*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([0.0]) 515*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([0.0]) 516*da0073e9SAndroid Build Coastguard Worker 517*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 518*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 519*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 520*da0073e9SAndroid Build Coastguard Worker x = 
torch.randn(1, 1) 521*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False)) 522*da0073e9SAndroid Build Coastguard Worker parameters['tied_bias'] = torch.tensor([5.0]) 523*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False)) 524*da0073e9SAndroid Build Coastguard Worker del parameters['tied_bias'] 525*da0073e9SAndroid Build Coastguard Worker parameters['tied_buffer'] = torch.tensor([5.0]) 526*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False)) 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 529*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 530*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 531*da0073e9SAndroid Build Coastguard Worker ]) 532*da0073e9SAndroid Build Coastguard Worker def test_reparametrize_tie_weights_strict(self, functional_call): 533*da0073e9SAndroid Build Coastguard Worker module = MockTiedModule() 534*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[2.0]]) 535*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([5.0]) 536*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([3.0]) 537*da0073e9SAndroid Build Coastguard Worker extra = torch.tensor([1.0]) 538*da0073e9SAndroid Build Coastguard Worker 539*da0073e9SAndroid Build Coastguard Worker # Tie weights no error 540*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 541*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 542*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 543*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 544*da0073e9SAndroid Build Coastguard Worker with 
self._ensure_module_unchanged( 545*da0073e9SAndroid Build Coastguard Worker module, 546*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a successful call', 547*da0073e9SAndroid Build Coastguard Worker ): 548*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x, tie_weights=True, strict=True) 549*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + bias + bias + buffer + buffer) 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker # Tie weights without flag 552*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 553*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 554*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 555*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 556*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 557*da0073e9SAndroid Build Coastguard Worker module, 558*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 559*da0073e9SAndroid Build Coastguard Worker ): 560*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 561*da0073e9SAndroid Build Coastguard Worker RuntimeError, 562*da0073e9SAndroid Build Coastguard Worker re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."), 563*da0073e9SAndroid Build Coastguard Worker ): 564*da0073e9SAndroid Build Coastguard Worker out = functional_call(module, parameters, x, tie_weights=False, strict=True) 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker # Tie some weights 567*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 568*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 569*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 570*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 571*da0073e9SAndroid Build Coastguard 
Worker module, 572*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 573*da0073e9SAndroid Build Coastguard Worker ): 574*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 575*da0073e9SAndroid Build Coastguard Worker RuntimeError, 576*da0073e9SAndroid Build Coastguard Worker re.escape("Missing key(s): 'l1.bias', 'tied_bias'."), 577*da0073e9SAndroid Build Coastguard Worker ): 578*da0073e9SAndroid Build Coastguard Worker out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True) 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker # Tie weights with extra keys 581*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 582*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 583*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer, 584*da0073e9SAndroid Build Coastguard Worker 'extra': extra} 585*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 586*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 587*da0073e9SAndroid Build Coastguard Worker module, 588*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 589*da0073e9SAndroid Build Coastguard Worker ): 590*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 591*da0073e9SAndroid Build Coastguard Worker RuntimeError, 592*da0073e9SAndroid Build Coastguard Worker re.escape("Unexpected key(s): 'extra'."), 593*da0073e9SAndroid Build Coastguard Worker ): 594*da0073e9SAndroid Build Coastguard Worker out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True) 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker # Tie weights with extra keys and without flag 597*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 598*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 
599*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer, 600*da0073e9SAndroid Build Coastguard Worker 'extra': extra} 601*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 602*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 603*da0073e9SAndroid Build Coastguard Worker module, 604*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 605*da0073e9SAndroid Build Coastguard Worker ): 606*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 607*da0073e9SAndroid Build Coastguard Worker RuntimeError, 608*da0073e9SAndroid Build Coastguard Worker re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."), 609*da0073e9SAndroid Build Coastguard Worker ): 610*da0073e9SAndroid Build Coastguard Worker out = stateless.functional_call(module, parameters, x, tie_weights=False, strict=True) 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker # Tie some weights with extra keys 613*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 614*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer, 615*da0073e9SAndroid Build Coastguard Worker 'extra': extra} 616*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 617*da0073e9SAndroid Build Coastguard Worker with self._ensure_module_unchanged( 618*da0073e9SAndroid Build Coastguard Worker module, 619*da0073e9SAndroid Build Coastguard Worker 'the module should not have been modified by a failed call', 620*da0073e9SAndroid Build Coastguard Worker ): 621*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 622*da0073e9SAndroid Build Coastguard Worker RuntimeError, 623*da0073e9SAndroid Build Coastguard Worker re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'l1.bias', 'tied_bias'."), 624*da0073e9SAndroid Build Coastguard Worker ): 625*da0073e9SAndroid Build 
Coastguard Worker out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True) 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 628*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 629*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 630*da0073e9SAndroid Build Coastguard Worker ]) 631*da0073e9SAndroid Build Coastguard Worker def test_setattr(self, functional_call): 632*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 633*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 634*da0073e9SAndroid Build Coastguard Worker super().__init__() 635*da0073e9SAndroid Build Coastguard Worker self.foo = torch.nn.Buffer(torch.tensor([0.0])) 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 638*da0073e9SAndroid Build Coastguard Worker self.foo = self.foo + 1 639*da0073e9SAndroid Build Coastguard Worker return x + self.foo 640*da0073e9SAndroid Build Coastguard Worker 641*da0073e9SAndroid Build Coastguard Worker foo = torch.tensor([2.0]) 642*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 643*da0073e9SAndroid Build Coastguard Worker a = {'foo': foo} 644*da0073e9SAndroid Build Coastguard Worker mod = Foo() 645*da0073e9SAndroid Build Coastguard Worker functional_call(mod, a, x) 646*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod.foo, torch.tensor([0.0])) 647*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a['foo'], torch.tensor([3.0])) 648*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo, torch.tensor([2.0])) 649*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a['foo'] is not foo) 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 652*da0073e9SAndroid Build Coastguard Worker 
subtest(torch.func.functional_call, "torch_func"), 653*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 654*da0073e9SAndroid Build Coastguard Worker ]) 655*da0073e9SAndroid Build Coastguard Worker def test_in_place_operator(self, functional_call): 656*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 657*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 658*da0073e9SAndroid Build Coastguard Worker super().__init__() 659*da0073e9SAndroid Build Coastguard Worker self.foo = torch.nn.Buffer(torch.tensor([0.0])) 660*da0073e9SAndroid Build Coastguard Worker 661*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 662*da0073e9SAndroid Build Coastguard Worker self.foo.add_(1) 663*da0073e9SAndroid Build Coastguard Worker return x + self.foo 664*da0073e9SAndroid Build Coastguard Worker 665*da0073e9SAndroid Build Coastguard Worker foo = torch.tensor([2.0]) 666*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 667*da0073e9SAndroid Build Coastguard Worker a = {'foo': foo} 668*da0073e9SAndroid Build Coastguard Worker mod = Foo() 669*da0073e9SAndroid Build Coastguard Worker functional_call(mod, a, x) 670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod.foo, torch.tensor([0.0])) 671*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a['foo'], torch.tensor([3.0])) 672*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo, torch.tensor([3.0])) 673*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a['foo'] is foo) 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 676*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 677*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 678*da0073e9SAndroid Build Coastguard Worker ]) 679*da0073e9SAndroid Build Coastguard Worker def test_setattr_strict(self, 
functional_call): 680*da0073e9SAndroid Build Coastguard Worker class Bar(torch.nn.Module): 681*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 682*da0073e9SAndroid Build Coastguard Worker super().__init__() 683*da0073e9SAndroid Build Coastguard Worker assert not hasattr(self, 'extra') 684*da0073e9SAndroid Build Coastguard Worker 685*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 686*da0073e9SAndroid Build Coastguard Worker return x + self.extra 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker a = {'extra': torch.zeros(())} 689*da0073e9SAndroid Build Coastguard Worker mod = Bar() 690*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not hasattr(mod, 'extra')) 691*da0073e9SAndroid Build Coastguard Worker out = functional_call(mod, a, torch.ones(())) 692*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, torch.ones(())) 693*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not hasattr(mod, 'extra')) 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker a = {'extra': torch.zeros(())} 696*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 697*da0073e9SAndroid Build Coastguard Worker RuntimeError, 698*da0073e9SAndroid Build Coastguard Worker re.escape("Unexpected key(s): 'extra'."), 699*da0073e9SAndroid Build Coastguard Worker ): 700*da0073e9SAndroid Build Coastguard Worker out = functional_call(mod, a, torch.ones(()), strict=True) 701*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not hasattr(mod, 'extra')) 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker a = {} 704*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 705*da0073e9SAndroid Build Coastguard Worker AttributeError, 706*da0073e9SAndroid Build Coastguard Worker re.escape("'Bar' object has no attribute 'extra'"), 707*da0073e9SAndroid Build Coastguard Worker ): 708*da0073e9SAndroid Build 
Coastguard Worker out = functional_call(mod, a, torch.ones(())) 709*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not hasattr(mod, 'extra')) 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker a = {} 712*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 713*da0073e9SAndroid Build Coastguard Worker AttributeError, 714*da0073e9SAndroid Build Coastguard Worker re.escape("'Bar' object has no attribute 'extra'"), 715*da0073e9SAndroid Build Coastguard Worker ): 716*da0073e9SAndroid Build Coastguard Worker out = functional_call(mod, a, torch.ones(()), strict=True) 717*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not hasattr(mod, 'extra')) 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 720*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 721*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 722*da0073e9SAndroid Build Coastguard Worker ]) 723*da0073e9SAndroid Build Coastguard Worker def test_functional_call_with_kwargs(self, functional_call): 724*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 725*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 726*da0073e9SAndroid Build Coastguard Worker super().__init__() 727*da0073e9SAndroid Build Coastguard Worker self.x = x 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker def forward(self, inp, *, other_inp): 730*da0073e9SAndroid Build Coastguard Worker return inp * self.x + other_inp 731*da0073e9SAndroid Build Coastguard Worker 732*da0073e9SAndroid Build Coastguard Worker a = {'x': torch.zeros(2, 3)} 733*da0073e9SAndroid Build Coastguard Worker mod = Foo(torch.randn(2, 3)) 734*da0073e9SAndroid Build Coastguard Worker inp, other_inp = torch.randn(2, 3), torch.randn(2, 3) 735*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(TypeError, "missing 1 required keyword-only argument: 'other_inp'"): 736*da0073e9SAndroid Build Coastguard Worker functional_call(mod, a, inp) 737*da0073e9SAndroid Build Coastguard Worker res = functional_call(mod, a, inp, {'other_inp': other_inp}) 738*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, other_inp) 739*da0073e9SAndroid Build Coastguard Worker res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp}) 740*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, res_1) 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker def test_functional_call_tuple_dicts(self): 743*da0073e9SAndroid Build Coastguard Worker mod = MockModule() 744*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 1)) 745*da0073e9SAndroid Build Coastguard Worker parameters = {k: torch.ones_like(v) for k, v in mod.named_parameters()} 746*da0073e9SAndroid Build Coastguard Worker buffers = {k: torch.zeros_like(v) for k, v in mod.named_buffers()} 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Worker # two dictionaries 749*da0073e9SAndroid Build Coastguard Worker res = torch.func.functional_call(mod, (parameters, buffers), x) 750*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, x + 1) 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker # no dictionaries 753*da0073e9SAndroid Build Coastguard Worker res = torch.func.functional_call(mod, (), x) 754*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, mod(x)) 755*da0073e9SAndroid Build Coastguard Worker 756*da0073e9SAndroid Build Coastguard Worker # three dictonaries 757*da0073e9SAndroid Build Coastguard Worker a = ({'l1.weight': torch.ones(1, 1)}, {'l1.bias': torch.ones(1)}, {'buffer': torch.zeros(1)}) 758*da0073e9SAndroid Build Coastguard Worker res = torch.func.functional_call(mod, a, x) 759*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(res, x + 1) 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker def test_functional_call_multiple_dicts_error(self): 762*da0073e9SAndroid Build Coastguard Worker mod = MockModule() 763*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 1)) 764*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': torch.zeros((1, 1)), 'l1.bias': torch.zeros((1, 1))} 765*da0073e9SAndroid Build Coastguard Worker repeated_parameters = {'l1.weight': torch.ones((1, 1))} 766*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 767*da0073e9SAndroid Build Coastguard Worker ValueError, 768*da0073e9SAndroid Build Coastguard Worker re.escape("['l1.weight'] appeared in multiple dictionaries"), 769*da0073e9SAndroid Build Coastguard Worker ): 770*da0073e9SAndroid Build Coastguard Worker torch.func.functional_call(mod, (parameters, repeated_parameters), x) 771*da0073e9SAndroid Build Coastguard Worker 772*da0073e9SAndroid Build Coastguard Worker @parametrize("functional_call", [ 773*da0073e9SAndroid Build Coastguard Worker subtest(torch.func.functional_call, "torch_func"), 774*da0073e9SAndroid Build Coastguard Worker subtest(stateless.functional_call, "stateless") 775*da0073e9SAndroid Build Coastguard Worker ]) 776*da0073e9SAndroid Build Coastguard Worker def test_functional_call_member_reference(self, functional_call): 777*da0073e9SAndroid Build Coastguard Worker class Module(torch.nn.Module): 778*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 779*da0073e9SAndroid Build Coastguard Worker super().__init__() 780*da0073e9SAndroid Build Coastguard Worker self.l1 = torch.nn.Linear(1, 1) 781*da0073e9SAndroid Build Coastguard Worker self.buffer = torch.nn.Buffer(torch.ones(1)) 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 784*da0073e9SAndroid Build Coastguard Worker parameters = tuple(self.parameters()) 785*da0073e9SAndroid 
Build Coastguard Worker buffers = tuple(self.buffers()) 786*da0073e9SAndroid Build Coastguard Worker return self.l1(x) + self.buffer, parameters, buffers 787*da0073e9SAndroid Build Coastguard Worker 788*da0073e9SAndroid Build Coastguard Worker module = Module() 789*da0073e9SAndroid Build Coastguard Worker weight = torch.tensor([[2.0]]) 790*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([5.0]) 791*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([3.0]) 792*da0073e9SAndroid Build Coastguard Worker extra = torch.tensor([1.0]) 793*da0073e9SAndroid Build Coastguard Worker extra_p = torch.nn.Parameter(extra) 794*da0073e9SAndroid Build Coastguard Worker 795*da0073e9SAndroid Build Coastguard Worker # All weights 796*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 797*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 798*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 799*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 800*da0073e9SAndroid Build Coastguard Worker out, parameters, buffers = functional_call(module, parameters, x) 801*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + bias + buffer) 802*da0073e9SAndroid Build Coastguard Worker self.assertEqual(parameters, (weight, bias)) 803*da0073e9SAndroid Build Coastguard Worker self.assertEqual(buffers, (buffer,)) 804*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias)))) 805*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,)))) 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker # Some weights 808*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight} 809*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 810*da0073e9SAndroid Build Coastguard Worker out, parameters, buffers = functional_call(module, parameters, x) 
811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + module.l1.bias + module.buffer) 812*da0073e9SAndroid Build Coastguard Worker self.assertEqual(parameters, (weight, module.l1.bias)) 813*da0073e9SAndroid Build Coastguard Worker self.assertEqual(buffers, (module.buffer,)) 814*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias)))) 815*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,)))) 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker # All weights with extra keys 818*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 819*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 820*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer, 821*da0073e9SAndroid Build Coastguard Worker 'l1.extra': extra} 822*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 823*da0073e9SAndroid Build Coastguard Worker out, parameters, buffers = functional_call(module, parameters, x) 824*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + bias + buffer) 825*da0073e9SAndroid Build Coastguard Worker self.assertEqual(parameters, (weight, bias)) 826*da0073e9SAndroid Build Coastguard Worker self.assertEqual(buffers, (buffer,)) 827*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias)))) 828*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,)))) 829*da0073e9SAndroid Build Coastguard Worker 830*da0073e9SAndroid Build Coastguard Worker # All weights with extra keys with parameters 831*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 832*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 833*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer, 834*da0073e9SAndroid Build 
Coastguard Worker 'l1.extra': extra_p} 835*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 836*da0073e9SAndroid Build Coastguard Worker out, parameters, buffers = functional_call(module, parameters, x) 837*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + bias + buffer) 838*da0073e9SAndroid Build Coastguard Worker self.assertEqual(parameters, (weight, bias, extra_p)) 839*da0073e9SAndroid Build Coastguard Worker self.assertEqual(buffers, (buffer,)) 840*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias, extra_p)))) 841*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,)))) 842*da0073e9SAndroid Build Coastguard Worker 843*da0073e9SAndroid Build Coastguard Worker # Some weights with extra keys 844*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 845*da0073e9SAndroid Build Coastguard Worker 'l1.extra': extra} 846*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 847*da0073e9SAndroid Build Coastguard Worker out, parameters, buffers = functional_call(module, parameters, x) 848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + module.l1.bias + module.buffer) 849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(parameters, (weight, module.l1.bias)) 850*da0073e9SAndroid Build Coastguard Worker self.assertEqual(buffers, (module.buffer)) 851*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias)))) 852*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,)))) 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker # Some weights with extra keys with parameters 855*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 856*da0073e9SAndroid Build Coastguard Worker 
'l1.extra': extra_p} 857*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 858*da0073e9SAndroid Build Coastguard Worker out, parameters, buffers = functional_call(module, parameters, x) 859*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + module.l1.bias + module.buffer) 860*da0073e9SAndroid Build Coastguard Worker self.assertEqual(parameters, (weight, module.l1.bias, extra_p)) 861*da0073e9SAndroid Build Coastguard Worker self.assertEqual(buffers, (module.buffer)) 862*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias, extra_p)))) 863*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,)))) 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Worker # Set None 866*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 867*da0073e9SAndroid Build Coastguard Worker 'l1.bias': None} 868*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 869*da0073e9SAndroid Build Coastguard Worker out, parameters, buffers = functional_call(module, parameters, x) 870*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, x * weight + module.buffer) 871*da0073e9SAndroid Build Coastguard Worker self.assertEqual(parameters, (weight,)) 872*da0073e9SAndroid Build Coastguard Worker self.assertEqual(buffers, (module.buffer)) 873*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight,)))) 874*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,)))) 875*da0073e9SAndroid Build Coastguard Worker 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Workerclass TestStatelessDeprecation(TestCase): 878*da0073e9SAndroid Build Coastguard Worker def test_private_stateless_warns(self): 879*da0073e9SAndroid Build Coastguard 
Worker script = """ 880*da0073e9SAndroid Build Coastguard Workerimport torch 881*da0073e9SAndroid Build Coastguard Workerimport warnings 882*da0073e9SAndroid Build Coastguard Worker 883*da0073e9SAndroid Build Coastguard Workerwith warnings.catch_warnings(record=True) as w: 884*da0073e9SAndroid Build Coastguard Worker from torch.nn.utils import _stateless 885*da0073e9SAndroid Build Coastguard Worker 886*da0073e9SAndroid Build Coastguard Workerexit(len(w)) 887*da0073e9SAndroid Build Coastguard Worker""" 888*da0073e9SAndroid Build Coastguard Worker try: 889*da0073e9SAndroid Build Coastguard Worker subprocess.check_output( 890*da0073e9SAndroid Build Coastguard Worker [sys.executable, '-W', 'always', '-c', script], 891*da0073e9SAndroid Build Coastguard Worker stderr=subprocess.STDOUT, 892*da0073e9SAndroid Build Coastguard Worker # On Windows, opening the subprocess with the default CWD makes `import torch` 893*da0073e9SAndroid Build Coastguard Worker # fail, so just set CWD to this script's directory 894*da0073e9SAndroid Build Coastguard Worker cwd=os.path.dirname(os.path.realpath(__file__)),) 895*da0073e9SAndroid Build Coastguard Worker except subprocess.CalledProcessError as e: 896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.returncode, 1) 897*da0073e9SAndroid Build Coastguard Worker else: 898*da0073e9SAndroid Build Coastguard Worker self.assertTrue(False, "No warning was raised.") 899*da0073e9SAndroid Build Coastguard Worker 900*da0073e9SAndroid Build Coastguard Worker def test_stateless_functional_call_warns(self): 901*da0073e9SAndroid Build Coastguard Worker m = torch.nn.Linear(1, 1) 902*da0073e9SAndroid Build Coastguard Worker params = dict(m.named_parameters()) 903*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 1) 904*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(FutureWarning, "Please use `torch.func.functional_call`"): 905*da0073e9SAndroid Build Coastguard Worker stateless.functional_call(m, params, x) 


class TestPythonOptimizeMode(TestCase):
    """Smoke test: the torch imports below must succeed under ``python -OO``."""

    def test_runs_with_optimize_flag(self):
        """Import torch and torch._functorch.deprecated in a ``-OO`` subprocess.

        The test fails if the child interpreter exits with a non-zero status.
        The child's combined stdout/stderr is included in the failure message
        so the underlying import error is visible in the test log.
        """
        script = "import torch; import torch._functorch.deprecated"
        proc = subprocess.run(
            [sys.executable, "-OO", "-c", script],
            stdout=subprocess.PIPE,
            # Fold stderr into stdout so any traceback ends up in proc.stdout.
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),
        )
        if proc.returncode != 0:
            self.fail(
                "Import failed while running python in optimized mode "
                f"(exit code {proc.returncode}):\n"
                f"{proc.stdout.decode(errors='replace')}"
            )


instantiate_parametrized_tests(
    TestStatelessFunctionalAPI,
)

if __name__ == '__main__':
    run_tests()