xref: /aosp_15_r20/external/pytorch/test/test_stateless.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: nn"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport re
6*da0073e9SAndroid Build Coastguard Workerimport subprocess
7*da0073e9SAndroid Build Coastguard Workerimport sys
8*da0073e9SAndroid Build Coastguard Workerimport unittest
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport torch
11*da0073e9SAndroid Build Coastguard Workerimport torch.nn.utils.stateless as stateless
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_MULTIGPU
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase, parametrize, instantiate_parametrized_tests, \
14*da0073e9SAndroid Build Coastguard Worker    subtest
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerclass MockModule(torch.nn.Module):
18*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
19*da0073e9SAndroid Build Coastguard Worker        super().__init__()
20*da0073e9SAndroid Build Coastguard Worker        self.l1 = torch.nn.Linear(1, 1)
21*da0073e9SAndroid Build Coastguard Worker        self.buffer = torch.nn.Buffer(torch.ones(1))
22*da0073e9SAndroid Build Coastguard Worker        self.foo = 0.0
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker    def forward(self, x):
25*da0073e9SAndroid Build Coastguard Worker        return self.l1(x) + self.buffer
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Workerclass MockTiedModule(torch.nn.Module):
29*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
30*da0073e9SAndroid Build Coastguard Worker        super().__init__()
31*da0073e9SAndroid Build Coastguard Worker        self.l1 = torch.nn.Linear(1, 1)
32*da0073e9SAndroid Build Coastguard Worker        self.tied_bias = self.l1.bias
33*da0073e9SAndroid Build Coastguard Worker        self.buffer = torch.nn.Buffer(torch.ones(1))
34*da0073e9SAndroid Build Coastguard Worker        self.tied_buffer = self.buffer
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker    def forward(self, x):
37*da0073e9SAndroid Build Coastguard Worker        return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Workerclass TestStatelessFunctionalAPI(TestCase):
41*da0073e9SAndroid Build Coastguard Worker    def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1, 1)).to(device)
44*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[1.0]], device=device)
45*da0073e9SAndroid Build Coastguard Worker        bias = torch.tensor([0.0], device=device)
46*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([0.0], device=device)
47*da0073e9SAndroid Build Coastguard Worker        if prefix != '':
48*da0073e9SAndroid Build Coastguard Worker            parameters = {f'{prefix}.l1.weight': weight,
49*da0073e9SAndroid Build Coastguard Worker                          f'{prefix}.l1.bias': bias,
50*da0073e9SAndroid Build Coastguard Worker                          f'{prefix}.buffer': buffer}
51*da0073e9SAndroid Build Coastguard Worker        else:
52*da0073e9SAndroid Build Coastguard Worker            parameters = {'l1.weight': weight,
53*da0073e9SAndroid Build Coastguard Worker                          'l1.bias': bias,
54*da0073e9SAndroid Build Coastguard Worker                          'buffer': buffer}
55*da0073e9SAndroid Build Coastguard Worker        to_check = module
56*da0073e9SAndroid Build Coastguard Worker        if prefix != '':
57*da0073e9SAndroid Build Coastguard Worker            to_check = getattr(module, prefix)
58*da0073e9SAndroid Build Coastguard Worker        prev_weight = to_check.l1.weight.clone()
59*da0073e9SAndroid Build Coastguard Worker        prev_buffer = to_check.buffer.clone()
60*da0073e9SAndroid Build Coastguard Worker        # the parameters represent an identity function contrary to the
61*da0073e9SAndroid Build Coastguard Worker        # existing params in module. So here we expect the result to be the
62*da0073e9SAndroid Build Coastguard Worker        # same as the input if the weight swapping went well.
63*da0073e9SAndroid Build Coastguard Worker        res = functional_call(module, parameters, x)
64*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, res)
65*da0073e9SAndroid Build Coastguard Worker        # check that the weight remain unmodified
66*da0073e9SAndroid Build Coastguard Worker        cur_weight = to_check.l1.weight
67*da0073e9SAndroid Build Coastguard Worker        cur_buffer = to_check.buffer
68*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cur_weight, prev_weight)
69*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cur_buffer, prev_buffer)
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker    @contextlib.contextmanager
72*da0073e9SAndroid Build Coastguard Worker    def _ensure_module_unchanged(self, module, message):
73*da0073e9SAndroid Build Coastguard Worker        orig_parameters, orig_buffers = tuple(module.parameters()), tuple(module.buffers())
74*da0073e9SAndroid Build Coastguard Worker        orig_tensors = orig_parameters + orig_buffers
75*da0073e9SAndroid Build Coastguard Worker        orig_tensors_values = tuple(t.clone() for t in orig_tensors)
76*da0073e9SAndroid Build Coastguard Worker        try:
77*da0073e9SAndroid Build Coastguard Worker            yield module
78*da0073e9SAndroid Build Coastguard Worker        finally:
79*da0073e9SAndroid Build Coastguard Worker            parameters, buffers = tuple(module.parameters()), tuple(module.buffers())
80*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
81*da0073e9SAndroid Build Coastguard Worker                len(parameters) == len(orig_parameters)
82*da0073e9SAndroid Build Coastguard Worker                and len(buffers) == len(orig_buffers)
83*da0073e9SAndroid Build Coastguard Worker                and all(
84*da0073e9SAndroid Build Coastguard Worker                    t1 is t2 and torch.allclose(t1, t3)
85*da0073e9SAndroid Build Coastguard Worker                    for t1, t2, t3 in zip(
86*da0073e9SAndroid Build Coastguard Worker                        orig_tensors,
87*da0073e9SAndroid Build Coastguard Worker                        parameters + buffers,
88*da0073e9SAndroid Build Coastguard Worker                        orig_tensors_values,
89*da0073e9SAndroid Build Coastguard Worker                    )
90*da0073e9SAndroid Build Coastguard Worker                ),
91*da0073e9SAndroid Build Coastguard Worker                message,
92*da0073e9SAndroid Build Coastguard Worker            )
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
95*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
96*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
97*da0073e9SAndroid Build Coastguard Worker    ])
def test_functional_call(self, functional_call):
    """A plain mock module accepts substituted tensors and stays unchanged."""
    self._run_call_with_mock_module(MockModule(), functional_call)
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
103*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
104*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
105*da0073e9SAndroid Build Coastguard Worker    ])
def test_functional_call_with_jit(self, functional_call):
    """Scripted and traced modules are both rejected with a RuntimeError."""
    module = MockModule()

    def expect_rejected(jitted):
        # Both flavours must fail with the same "Jitted modules" error.
        with self.assertRaisesRegex(RuntimeError, r'used with Jitted modules'):
            self._run_call_with_mock_module(jitted, functional_call)

    expect_rejected(torch.jit.script(module))
    example_input = torch.rand((1, 1))
    expect_rejected(torch.jit.trace(module, example_input))
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
123*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("This doesn't work right now")
124*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
125*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
126*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
127*da0073e9SAndroid Build Coastguard Worker    ])
def test_functional_call_with_data_parallel(self, functional_call):
    """Substitute tensors through an ``nn.DataParallel`` wrapper on two GPUs.

    The wrapped mock module is reached through the ``module`` attribute of
    the wrapper, hence ``prefix='module'``.
    """
    wrapped = MockModule().cuda()
    dp_module = torch.nn.DataParallel(wrapped, [0, 1])
    self._run_call_with_mock_module(
        dp_module,
        functional_call,
        device='cuda',
        prefix='module',
    )
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
135*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
136*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
137*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
138*da0073e9SAndroid Build Coastguard Worker    ])
def test_functional_call_with_data_parallel_error(self, functional_call):
    """Calling through an ``nn.DataParallel`` wrapper raises a RuntimeError."""
    dp_module = torch.nn.DataParallel(MockModule().cuda(), [0, 1])
    expected_error = r'used with nn.DataParallel module'
    with self.assertRaisesRegex(RuntimeError, expected_error):
        replacement = {'module.weight': torch.zeros(5, device='cuda')}
        inputs = (torch.ones(2, 5, device='cuda'),)
        functional_call(dp_module, replacement, inputs)
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
150*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
151*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
152*da0073e9SAndroid Build Coastguard Worker    ])
def test_functional_call_with_gradient(self, functional_call):
    """Backward populates grads on the supplied tensors, not on the module's."""
    module = MockModule()
    x = torch.rand((1, 1))
    weight = torch.tensor([[1.0]], requires_grad=True)
    bias = torch.tensor([0.0], requires_grad=True)
    buffer = torch.tensor([0.0])
    names = ('l1.weight', 'l1.bias', 'buffer')
    parameters = dict(zip(names, (weight, bias, buffer)))

    res = functional_call(module, parameters, x)
    res.backward()

    # The supplied tensors that require grad received one ...
    for supplied in (weight, bias):
        self.assertIsNotNone(supplied.grad)
    # ... the supplied buffer (no requires_grad) did not ...
    self.assertIsNone(buffer.grad)
    # ... and neither did the module's own state and buffers.
    for own in (module.l1.weight, module.l1.bias, module.buffer):
        self.assertIsNone(own.grad)
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
174*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
175*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
176*da0073e9SAndroid Build Coastguard Worker    ])
177*da0073e9SAndroid Build Coastguard Worker    def test_functional_batch_norm(self, functional_call):
178*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.BatchNorm1d(10)
179*da0073e9SAndroid Build Coastguard Worker        module.train()  # Allow stats update
180*da0073e9SAndroid Build Coastguard Worker        # lets replace the running_mean buffer and check if its correctly updated
181*da0073e9SAndroid Build Coastguard Worker        x = torch.full((20, 10), 128.0)
182*da0073e9SAndroid Build Coastguard Worker        rm = torch.zeros(10)
183*da0073e9SAndroid Build Coastguard Worker        parameters = {'running_mean': rm}
184*da0073e9SAndroid Build Coastguard Worker        prev_rm = module.running_mean.clone()
185*da0073e9SAndroid Build Coastguard Worker        res = functional_call(module, parameters, x)
186*da0073e9SAndroid Build Coastguard Worker        cur_rm = module.running_mean
187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cur_rm, prev_rm)
188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rm, torch.full((10,), 12.8))
189*da0073e9SAndroid Build Coastguard Worker        # Now run functional without reparametrization and check that the module has
190*da0073e9SAndroid Build Coastguard Worker        # been updated
191*da0073e9SAndroid Build Coastguard Worker        res = functional_call(module, {}, x)
192*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_mean, torch.full((10,), 12.8))
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
195*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
196*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
197*da0073e9SAndroid Build Coastguard Worker    ])
def test_circular_references(self, functional_call):
    """Keys that walk through a module cycle still reach the right tensors."""
    module = MockModule()
    # Add a circular reference
    module.l1.m = module
    x = torch.rand((1, 1))
    # Two of the keys go through the cycle (``l1.m`` is ``module`` itself).
    parameters = {
        'l1.m.l1.weight': torch.tensor([[1.0]]),
        'l1.bias': torch.tensor([0.0]),
        'l1.m.buffer': torch.tensor([0.0]),
    }
    weight_before = module.l1.weight.clone()
    buffer_before = module.buffer.clone()

    res = functional_call(module, parameters, x, tie_weights=False)
    self.assertEqual(x, res)

    # Check that the weights remain unmodified and were correctly accessed.
    weight_after = module.l1.weight
    buffer_after = module.buffer
    self.assertEqual(weight_after, weight_before)
    self.assertEqual(buffer_after, buffer_before)
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
220*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
221*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
222*da0073e9SAndroid Build Coastguard Worker    ])
def test_reparametrized_module_change_parametrization_original(self, functional_call):
    """Substituting a parametrization's ``original`` tensor leaves the module intact.

    Spectral norm is registered on ``l1``; the call replaces the
    ``original`` tensor inside that parametrization. Afterwards the
    parametrization must still be registered and ``l1.weight`` must be the
    same as before the call.
    """
    original_key = 'l1.parametrizations.weight.original'
    module = MockModule()
    torch.nn.utils.parametrizations.spectral_norm(module.l1)
    # assertIn (rather than assertTrue(... in ...)) reports the missing key
    # and the container contents when it fails.
    self.assertIn(original_key, dict(module.named_parameters()))
    orig_sn_weight = module.l1.weight.clone()
    x = torch.rand((1, 1))
    # We substitute the parameter inside the parametrization; the
    # parametrization itself is not overwritten, so it will be applied with
    # a different value for the original tensor.
    parameters = {original_key: torch.nn.Parameter(torch.tensor([[1.0]])),
                  'l1.bias': torch.tensor([0.0]),
                  'buffer': torch.tensor([0.0])}
    res = functional_call(module, parameters, x)
    self.assertEqual(x, res)
    # Verify that the spectral normalization is still applied.
    self.assertIn(original_key, dict(module.named_parameters()))
    self.assertEqual(orig_sn_weight, module.l1.weight)
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
242*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
243*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
244*da0073e9SAndroid Build Coastguard Worker    ])
def test_reparametrize_module_fail_reset_to_original(self, functional_call):
    """A failing call must still restore the parametrized module.

    Spectral norm is registered on ``l1`` and the ``original`` tensor of the
    parametrization is substituted, but the input has the wrong shape so the
    call raises. Afterwards the parametrization must still be registered and
    ``l1.weight`` must be the same as before the call.
    """
    original_key = 'l1.parametrizations.weight.original'
    module = MockModule()
    torch.nn.utils.parametrizations.spectral_norm(module.l1)
    # assertIn (rather than assertTrue(... in ...)) reports the missing key
    # and the container contents when it fails.
    self.assertIn(original_key, dict(module.named_parameters()))
    orig_sn_weight = module.l1.weight.clone()
    # We substitute the parameter inside the parametrization; the
    # parametrization itself is not overwritten, so it will be applied with
    # a different value for the original tensor.
    parameters = {original_key: torch.nn.Parameter(torch.tensor([[1.0]])),
                  'l1.bias': torch.tensor([0.0]),
                  'buffer': torch.tensor([0.0])}

    @torch._dynamo.disable
    def _error_case():
        x = torch.rand((4, 5))  # to work, it should be of size (1, 1)
        functional_call(module, parameters, x)  # this call will fail because x is the wrong size

    # Only the failing call itself is placed under the expected-error guard.
    with self.assertRaisesRegex(RuntimeError, "shapes cannot be multiplied"):
        _error_case()

    # Verify that the spectral normalization is still applied.
    self.assertIn(original_key, dict(module.named_parameters()))
    self.assertEqual(orig_sn_weight, module.l1.weight)
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
269*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
270*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
271*da0073e9SAndroid Build Coastguard Worker    ])
def test_reparametrize_some_weights(self, functional_call):
    """Only the supplied keys are substituted; the rest comes from the module.

    Two mappings are tried: one with just ``l1.weight`` and one that also
    carries an ``extra`` key that the module does not have. In both cases
    the output must use the supplied weight together with the module's own
    bias and buffer.
    """
    module = MockModule()
    weight = torch.tensor([[2.0]])
    extra = torch.tensor([1.0])

    candidate_mappings = (
        {'l1.weight': weight},
        {'l1.weight': weight, 'extra': extra},
    )
    for parameters in candidate_mappings:
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
291*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
292*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
293*da0073e9SAndroid Build Coastguard Worker    ])
def test_reparametrize_strict(self, functional_call):
    """``strict=True`` accepts an exact key set and rejects anything else.

    A mapping with every key succeeds. Mappings with missing keys,
    unexpected keys, or both raise a RuntimeError whose message lists them.
    In every case the module must be left unchanged.
    """
    module = MockModule()
    weight = torch.tensor([[2.0]])
    bias = torch.tensor([5.0])
    buffer = torch.tensor([3.0])
    extra = torch.tensor([1.0])
    complete = {'l1.weight': weight,
                'l1.bias': bias,
                'buffer': buffer}

    # All weights no error
    x = torch.randn(1, 1)
    with self._ensure_module_unchanged(
        module,
        'the module should not have been modified by a successful call',
    ):
        out = functional_call(module, complete, x, strict=True)
        self.assertEqual(out, x * weight + bias + buffer)

    missing_error = re.escape("Missing key(s): 'buffer', 'l1.bias'.")
    unexpected_error = re.escape("Unexpected key(s): 'extra'.")
    failing_cases = (
        # Some weights
        ({'l1.weight': weight}, missing_error),
        # Extra keys
        ({**complete, 'extra': extra}, unexpected_error),
        # Some weights with extra keys
        ({'l1.weight': weight, 'extra': extra}, unexpected_error + r'\s+' + missing_error),
    )
    for parameters, expected_error in failing_cases:
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ), self.assertRaisesRegex(RuntimeError, expected_error):
            functional_call(module, parameters, x, strict=True)
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
357*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
358*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
359*da0073e9SAndroid Build Coastguard Worker    ])
360*da0073e9SAndroid Build Coastguard Worker    def test_reparametrize_special(self, functional_call):
361*da0073e9SAndroid Build Coastguard Worker        class NonTensor:
362*da0073e9SAndroid Build Coastguard Worker            def __repr__(self):
363*da0073e9SAndroid Build Coastguard Worker                return f'<{self.__class__.__name__}>'
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker        module = MockModule()
366*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[2.0]])
367*da0073e9SAndroid Build Coastguard Worker        bias = torch.tensor([5.0])
368*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([3.0])
369*da0073e9SAndroid Build Coastguard Worker        non_tensor = NonTensor()
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker        # Set to None
372*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
373*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': None,
374*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
375*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
376*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
377*da0073e9SAndroid Build Coastguard Worker            module,
378*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a successful call',
379*da0073e9SAndroid Build Coastguard Worker        ):
380*da0073e9SAndroid Build Coastguard Worker            out = functional_call(module, parameters, x)
381*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, x * weight + buffer)
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Worker        # Set non-tensor
384*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': non_tensor}
385*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
386*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
387*da0073e9SAndroid Build Coastguard Worker            module,
388*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a failed call',
389*da0073e9SAndroid Build Coastguard Worker        ):
390*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
391*da0073e9SAndroid Build Coastguard Worker                TypeError,
392*da0073e9SAndroid Build Coastguard Worker                re.escape("<NonTensor> is not an instance of torch.Tensor"),
393*da0073e9SAndroid Build Coastguard Worker            ):
394*da0073e9SAndroid Build Coastguard Worker                out = functional_call(module, parameters, x)
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker        # Set non-tensor attribute
397*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight, 'foo': torch.tensor([1.0])}
398*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
399*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
400*da0073e9SAndroid Build Coastguard Worker            module,
401*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a failed call',
402*da0073e9SAndroid Build Coastguard Worker        ):
403*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
404*da0073e9SAndroid Build Coastguard Worker                TypeError,
405*da0073e9SAndroid Build Coastguard Worker                re.escape("attribute `foo`: 0.0 is not an instance of torch.Tensor"),
406*da0073e9SAndroid Build Coastguard Worker            ):
407*da0073e9SAndroid Build Coastguard Worker                out = functional_call(module, parameters, x)
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker        # Set non-exist submodule
410*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
411*da0073e9SAndroid Build Coastguard Worker                      'l2.bias': bias}
412*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
413*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
414*da0073e9SAndroid Build Coastguard Worker            module,
415*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a failed call',
416*da0073e9SAndroid Build Coastguard Worker        ):
417*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
418*da0073e9SAndroid Build Coastguard Worker                AttributeError,
419*da0073e9SAndroid Build Coastguard Worker                re.escape("MockModule has no attribute `l2`"),
420*da0073e9SAndroid Build Coastguard Worker            ):
421*da0073e9SAndroid Build Coastguard Worker                out = functional_call(module, parameters, x)
422*da0073e9SAndroid Build Coastguard Worker
423*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
424*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
425*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
426*da0073e9SAndroid Build Coastguard Worker    ])
427*da0073e9SAndroid Build Coastguard Worker    def test_tied_weights_warns(self, functional_call):
428*da0073e9SAndroid Build Coastguard Worker        module = MockModule()
429*da0073e9SAndroid Build Coastguard Worker        module.tied_bias = module.l1.bias
430*da0073e9SAndroid Build Coastguard Worker        module.tied_buffer = torch.nn.Buffer(module.buffer)
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
433*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
434*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
435*da0073e9SAndroid Build Coastguard Worker    ])
436*da0073e9SAndroid Build Coastguard Worker    def test_reparametrize_tie_weights(self, functional_call):
437*da0073e9SAndroid Build Coastguard Worker        module = MockTiedModule()
438*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[2.0]])
439*da0073e9SAndroid Build Coastguard Worker        bias = torch.tensor([5.0])
440*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([3.0])
441*da0073e9SAndroid Build Coastguard Worker        extra = torch.tensor([1.0])
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
444*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
445*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
446*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
447*da0073e9SAndroid Build Coastguard Worker        out = functional_call(module, parameters, x, tie_weights=True)
448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
451*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
452*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer,
453*da0073e9SAndroid Build Coastguard Worker                      'extra': extra}
454*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
455*da0073e9SAndroid Build Coastguard Worker        out = functional_call(module, parameters, x, tie_weights=True)
456*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
459*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
460*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
461*da0073e9SAndroid Build Coastguard Worker    ])
462*da0073e9SAndroid Build Coastguard Worker    def test_reparametrize_tie_some_weights(self, functional_call):
463*da0073e9SAndroid Build Coastguard Worker        module = MockTiedModule()
464*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[2.0]])
465*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([3.0])
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
468*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
469*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
470*da0073e9SAndroid Build Coastguard Worker        out = stateless.functional_call(module, parameters, x, tie_weights=True)
471*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)
472*da0073e9SAndroid Build Coastguard Worker
473*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
474*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
475*da0073e9SAndroid Build Coastguard Worker        subtest(stateless._functional_call, "stateless")
476*da0073e9SAndroid Build Coastguard Worker    ])
477*da0073e9SAndroid Build Coastguard Worker    def test_tied_weights_errors(self, functional_call):
478*da0073e9SAndroid Build Coastguard Worker        module = MockTiedModule()
479*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[1.0]])
480*da0073e9SAndroid Build Coastguard Worker        bias = torch.tensor([0.0])
481*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([0.0])
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
484*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
485*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
486*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
487*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker        # if tied values are the same tensors, shouldn't warn
490*da0073e9SAndroid Build Coastguard Worker        parameters['tied_bias'] = bias
491*da0073e9SAndroid Build Coastguard Worker        parameters['tied_buffer'] = buffer
492*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
493*da0073e9SAndroid Build Coastguard Worker        del parameters['tied_bias']
494*da0073e9SAndroid Build Coastguard Worker        del parameters['tied_buffer']
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
497*da0073e9SAndroid Build Coastguard Worker            ValueError,
498*da0073e9SAndroid Build Coastguard Worker            re.escape("functional_call got multiple values for keys ['l1.bias', 'tied_bias']"),
499*da0073e9SAndroid Build Coastguard Worker        ):
500*da0073e9SAndroid Build Coastguard Worker            parameters['tied_bias'] = torch.tensor([5.0])
501*da0073e9SAndroid Build Coastguard Worker            functional_call(module, parameters, x, tie_weights=True)
502*da0073e9SAndroid Build Coastguard Worker        del parameters['tied_bias']
503*da0073e9SAndroid Build Coastguard Worker
504*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
505*da0073e9SAndroid Build Coastguard Worker            ValueError,
506*da0073e9SAndroid Build Coastguard Worker            re.escape("functional_call got multiple values for keys ['buffer', 'tied_buffer']"),
507*da0073e9SAndroid Build Coastguard Worker        ):
508*da0073e9SAndroid Build Coastguard Worker            parameters['tied_buffer'] = torch.tensor([5.0])
509*da0073e9SAndroid Build Coastguard Worker            functional_call(module, parameters, x, tie_weights=True)
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker    def test_tied_weights_no_error_without_flag(self):
512*da0073e9SAndroid Build Coastguard Worker        module = MockTiedModule()
513*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[1.0]])
514*da0073e9SAndroid Build Coastguard Worker        bias = torch.tensor([0.0])
515*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([0.0])
516*da0073e9SAndroid Build Coastguard Worker
517*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
518*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
519*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
520*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
521*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
522*da0073e9SAndroid Build Coastguard Worker        parameters['tied_bias'] = torch.tensor([5.0])
523*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
524*da0073e9SAndroid Build Coastguard Worker        del parameters['tied_bias']
525*da0073e9SAndroid Build Coastguard Worker        parameters['tied_buffer'] = torch.tensor([5.0])
526*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
529*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
530*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
531*da0073e9SAndroid Build Coastguard Worker    ])
532*da0073e9SAndroid Build Coastguard Worker    def test_reparametrize_tie_weights_strict(self, functional_call):
533*da0073e9SAndroid Build Coastguard Worker        module = MockTiedModule()
534*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[2.0]])
535*da0073e9SAndroid Build Coastguard Worker        bias = torch.tensor([5.0])
536*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([3.0])
537*da0073e9SAndroid Build Coastguard Worker        extra = torch.tensor([1.0])
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker        # Tie weights no error
540*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
541*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
542*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
543*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
544*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
545*da0073e9SAndroid Build Coastguard Worker            module,
546*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a successful call',
547*da0073e9SAndroid Build Coastguard Worker        ):
548*da0073e9SAndroid Build Coastguard Worker            out = functional_call(module, parameters, x, tie_weights=True, strict=True)
549*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
550*da0073e9SAndroid Build Coastguard Worker
551*da0073e9SAndroid Build Coastguard Worker        # Tie weights without flag
552*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
553*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
554*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
555*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
556*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
557*da0073e9SAndroid Build Coastguard Worker            module,
558*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a failed call',
559*da0073e9SAndroid Build Coastguard Worker        ):
560*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
561*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
562*da0073e9SAndroid Build Coastguard Worker                re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
563*da0073e9SAndroid Build Coastguard Worker            ):
564*da0073e9SAndroid Build Coastguard Worker                out = functional_call(module, parameters, x, tie_weights=False, strict=True)
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker        # Tie some weights
567*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
568*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
569*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
570*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
571*da0073e9SAndroid Build Coastguard Worker            module,
572*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a failed call',
573*da0073e9SAndroid Build Coastguard Worker        ):
574*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
575*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
576*da0073e9SAndroid Build Coastguard Worker                re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
577*da0073e9SAndroid Build Coastguard Worker            ):
578*da0073e9SAndroid Build Coastguard Worker                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
579*da0073e9SAndroid Build Coastguard Worker
580*da0073e9SAndroid Build Coastguard Worker        # Tie weights with extra keys
581*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
582*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
583*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer,
584*da0073e9SAndroid Build Coastguard Worker                      'extra': extra}
585*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
586*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
587*da0073e9SAndroid Build Coastguard Worker            module,
588*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a failed call',
589*da0073e9SAndroid Build Coastguard Worker        ):
590*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
591*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
592*da0073e9SAndroid Build Coastguard Worker                re.escape("Unexpected key(s): 'extra'."),
593*da0073e9SAndroid Build Coastguard Worker            ):
594*da0073e9SAndroid Build Coastguard Worker                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker        # Tie weights with extra keys and without flag
597*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
598*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
599*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer,
600*da0073e9SAndroid Build Coastguard Worker                      'extra': extra}
601*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
602*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
603*da0073e9SAndroid Build Coastguard Worker            module,
604*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a failed call',
605*da0073e9SAndroid Build Coastguard Worker        ):
606*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
607*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
608*da0073e9SAndroid Build Coastguard Worker                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
609*da0073e9SAndroid Build Coastguard Worker            ):
610*da0073e9SAndroid Build Coastguard Worker                out = stateless.functional_call(module, parameters, x, tie_weights=False, strict=True)
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker        # Tie some weights with extra keys
613*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
614*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer,
615*da0073e9SAndroid Build Coastguard Worker                      'extra': extra}
616*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
617*da0073e9SAndroid Build Coastguard Worker        with self._ensure_module_unchanged(
618*da0073e9SAndroid Build Coastguard Worker            module,
619*da0073e9SAndroid Build Coastguard Worker            'the module should not have been modified by a failed call',
620*da0073e9SAndroid Build Coastguard Worker        ):
621*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
622*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
623*da0073e9SAndroid Build Coastguard Worker                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
624*da0073e9SAndroid Build Coastguard Worker            ):
625*da0073e9SAndroid Build Coastguard Worker                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)
626*da0073e9SAndroid Build Coastguard Worker
627*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
628*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
629*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
630*da0073e9SAndroid Build Coastguard Worker    ])
631*da0073e9SAndroid Build Coastguard Worker    def test_setattr(self, functional_call):
632*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
633*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
634*da0073e9SAndroid Build Coastguard Worker                super().__init__()
635*da0073e9SAndroid Build Coastguard Worker                self.foo = torch.nn.Buffer(torch.tensor([0.0]))
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
638*da0073e9SAndroid Build Coastguard Worker                self.foo = self.foo + 1
639*da0073e9SAndroid Build Coastguard Worker                return x + self.foo
640*da0073e9SAndroid Build Coastguard Worker
641*da0073e9SAndroid Build Coastguard Worker        foo = torch.tensor([2.0])
642*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
643*da0073e9SAndroid Build Coastguard Worker        a = {'foo': foo}
644*da0073e9SAndroid Build Coastguard Worker        mod = Foo()
645*da0073e9SAndroid Build Coastguard Worker        functional_call(mod, a, x)
646*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod.foo, torch.tensor([0.0]))
647*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a['foo'], torch.tensor([3.0]))
648*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo, torch.tensor([2.0]))
649*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a['foo'] is not foo)
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
652*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
653*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
654*da0073e9SAndroid Build Coastguard Worker    ])
655*da0073e9SAndroid Build Coastguard Worker    def test_in_place_operator(self, functional_call):
656*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
657*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
658*da0073e9SAndroid Build Coastguard Worker                super().__init__()
659*da0073e9SAndroid Build Coastguard Worker                self.foo = torch.nn.Buffer(torch.tensor([0.0]))
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
662*da0073e9SAndroid Build Coastguard Worker                self.foo.add_(1)
663*da0073e9SAndroid Build Coastguard Worker                return x + self.foo
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker        foo = torch.tensor([2.0])
666*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
667*da0073e9SAndroid Build Coastguard Worker        a = {'foo': foo}
668*da0073e9SAndroid Build Coastguard Worker        mod = Foo()
669*da0073e9SAndroid Build Coastguard Worker        functional_call(mod, a, x)
670*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod.foo, torch.tensor([0.0]))
671*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a['foo'], torch.tensor([3.0]))
672*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo, torch.tensor([3.0]))
673*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a['foo'] is foo)
674*da0073e9SAndroid Build Coastguard Worker
675*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
676*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
677*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
678*da0073e9SAndroid Build Coastguard Worker    ])
679*da0073e9SAndroid Build Coastguard Worker    def test_setattr_strict(self, functional_call):
680*da0073e9SAndroid Build Coastguard Worker        class Bar(torch.nn.Module):
681*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
682*da0073e9SAndroid Build Coastguard Worker                super().__init__()
683*da0073e9SAndroid Build Coastguard Worker                assert not hasattr(self, 'extra')
684*da0073e9SAndroid Build Coastguard Worker
685*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
686*da0073e9SAndroid Build Coastguard Worker                return x + self.extra
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker        a = {'extra': torch.zeros(())}
689*da0073e9SAndroid Build Coastguard Worker        mod = Bar()
690*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not hasattr(mod, 'extra'))
691*da0073e9SAndroid Build Coastguard Worker        out = functional_call(mod, a, torch.ones(()))
692*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, torch.ones(()))
693*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not hasattr(mod, 'extra'))
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker        a = {'extra': torch.zeros(())}
696*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
697*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
698*da0073e9SAndroid Build Coastguard Worker            re.escape("Unexpected key(s): 'extra'."),
699*da0073e9SAndroid Build Coastguard Worker        ):
700*da0073e9SAndroid Build Coastguard Worker            out = functional_call(mod, a, torch.ones(()), strict=True)
701*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not hasattr(mod, 'extra'))
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker        a = {}
704*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
705*da0073e9SAndroid Build Coastguard Worker            AttributeError,
706*da0073e9SAndroid Build Coastguard Worker            re.escape("'Bar' object has no attribute 'extra'"),
707*da0073e9SAndroid Build Coastguard Worker        ):
708*da0073e9SAndroid Build Coastguard Worker            out = functional_call(mod, a, torch.ones(()))
709*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not hasattr(mod, 'extra'))
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker        a = {}
712*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
713*da0073e9SAndroid Build Coastguard Worker            AttributeError,
714*da0073e9SAndroid Build Coastguard Worker            re.escape("'Bar' object has no attribute 'extra'"),
715*da0073e9SAndroid Build Coastguard Worker        ):
716*da0073e9SAndroid Build Coastguard Worker            out = functional_call(mod, a, torch.ones(()), strict=True)
717*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not hasattr(mod, 'extra'))
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
720*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
721*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
722*da0073e9SAndroid Build Coastguard Worker    ])
723*da0073e9SAndroid Build Coastguard Worker    def test_functional_call_with_kwargs(self, functional_call):
724*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
725*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
726*da0073e9SAndroid Build Coastguard Worker                super().__init__()
727*da0073e9SAndroid Build Coastguard Worker                self.x = x
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp, *, other_inp):
730*da0073e9SAndroid Build Coastguard Worker                return inp * self.x + other_inp
731*da0073e9SAndroid Build Coastguard Worker
732*da0073e9SAndroid Build Coastguard Worker        a = {'x': torch.zeros(2, 3)}
733*da0073e9SAndroid Build Coastguard Worker        mod = Foo(torch.randn(2, 3))
734*da0073e9SAndroid Build Coastguard Worker        inp, other_inp = torch.randn(2, 3), torch.randn(2, 3)
735*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "missing 1 required keyword-only argument: 'other_inp'"):
736*da0073e9SAndroid Build Coastguard Worker            functional_call(mod, a, inp)
737*da0073e9SAndroid Build Coastguard Worker        res = functional_call(mod, a, inp, {'other_inp': other_inp})
738*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, other_inp)
739*da0073e9SAndroid Build Coastguard Worker        res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp})
740*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_1)
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker    def test_functional_call_tuple_dicts(self):
743*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
744*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1, 1))
745*da0073e9SAndroid Build Coastguard Worker        parameters = {k: torch.ones_like(v) for k, v in mod.named_parameters()}
746*da0073e9SAndroid Build Coastguard Worker        buffers = {k: torch.zeros_like(v) for k, v in mod.named_buffers()}
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker        # two dictionaries
749*da0073e9SAndroid Build Coastguard Worker        res = torch.func.functional_call(mod, (parameters, buffers), x)
750*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, x + 1)
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker        # no dictionaries
753*da0073e9SAndroid Build Coastguard Worker        res = torch.func.functional_call(mod, (), x)
754*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, mod(x))
755*da0073e9SAndroid Build Coastguard Worker
756*da0073e9SAndroid Build Coastguard Worker        # three dictonaries
757*da0073e9SAndroid Build Coastguard Worker        a = ({'l1.weight': torch.ones(1, 1)}, {'l1.bias': torch.ones(1)}, {'buffer': torch.zeros(1)})
758*da0073e9SAndroid Build Coastguard Worker        res = torch.func.functional_call(mod, a, x)
759*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, x + 1)
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker    def test_functional_call_multiple_dicts_error(self):
762*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
763*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1, 1))
764*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': torch.zeros((1, 1)), 'l1.bias': torch.zeros((1, 1))}
765*da0073e9SAndroid Build Coastguard Worker        repeated_parameters = {'l1.weight': torch.ones((1, 1))}
766*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
767*da0073e9SAndroid Build Coastguard Worker            ValueError,
768*da0073e9SAndroid Build Coastguard Worker            re.escape("['l1.weight'] appeared in multiple dictionaries"),
769*da0073e9SAndroid Build Coastguard Worker        ):
770*da0073e9SAndroid Build Coastguard Worker            torch.func.functional_call(mod, (parameters, repeated_parameters), x)
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Worker    @parametrize("functional_call", [
773*da0073e9SAndroid Build Coastguard Worker        subtest(torch.func.functional_call, "torch_func"),
774*da0073e9SAndroid Build Coastguard Worker        subtest(stateless.functional_call, "stateless")
775*da0073e9SAndroid Build Coastguard Worker    ])
776*da0073e9SAndroid Build Coastguard Worker    def test_functional_call_member_reference(self, functional_call):
777*da0073e9SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
778*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
779*da0073e9SAndroid Build Coastguard Worker                super().__init__()
780*da0073e9SAndroid Build Coastguard Worker                self.l1 = torch.nn.Linear(1, 1)
781*da0073e9SAndroid Build Coastguard Worker                self.buffer = torch.nn.Buffer(torch.ones(1))
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
784*da0073e9SAndroid Build Coastguard Worker                parameters = tuple(self.parameters())
785*da0073e9SAndroid Build Coastguard Worker                buffers = tuple(self.buffers())
786*da0073e9SAndroid Build Coastguard Worker                return self.l1(x) + self.buffer, parameters, buffers
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker        module = Module()
789*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[2.0]])
790*da0073e9SAndroid Build Coastguard Worker        bias = torch.tensor([5.0])
791*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([3.0])
792*da0073e9SAndroid Build Coastguard Worker        extra = torch.tensor([1.0])
793*da0073e9SAndroid Build Coastguard Worker        extra_p = torch.nn.Parameter(extra)
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker        # All weights
796*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
797*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
798*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
799*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
800*da0073e9SAndroid Build Coastguard Worker        out, parameters, buffers = functional_call(module, parameters, x)
801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + bias + buffer)
802*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(parameters, (weight, bias))
803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(buffers, (buffer,))
804*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
805*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker        # Some weights
808*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight}
809*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
810*da0073e9SAndroid Build Coastguard Worker        out, parameters, buffers = functional_call(module, parameters, x)
811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
812*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(parameters, (weight, module.l1.bias))
813*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(buffers, (module.buffer,))
814*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
815*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker        # All weights with extra keys
818*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
819*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
820*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer,
821*da0073e9SAndroid Build Coastguard Worker                      'l1.extra': extra}
822*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
823*da0073e9SAndroid Build Coastguard Worker        out, parameters, buffers = functional_call(module, parameters, x)
824*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + bias + buffer)
825*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(parameters, (weight, bias))
826*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(buffers, (buffer,))
827*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
828*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
829*da0073e9SAndroid Build Coastguard Worker
830*da0073e9SAndroid Build Coastguard Worker        # All weights with extra keys with parameters
831*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
832*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
833*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer,
834*da0073e9SAndroid Build Coastguard Worker                      'l1.extra': extra_p}
835*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
836*da0073e9SAndroid Build Coastguard Worker        out, parameters, buffers = functional_call(module, parameters, x)
837*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + bias + buffer)
838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(parameters, (weight, bias, extra_p))
839*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(buffers, (buffer,))
840*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias, extra_p))))
841*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
842*da0073e9SAndroid Build Coastguard Worker
843*da0073e9SAndroid Build Coastguard Worker        # Some weights with extra keys
844*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
845*da0073e9SAndroid Build Coastguard Worker                      'l1.extra': extra}
846*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
847*da0073e9SAndroid Build Coastguard Worker        out, parameters, buffers = functional_call(module, parameters, x)
848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
849*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(parameters, (weight, module.l1.bias))
850*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(buffers, (module.buffer))
851*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
852*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker        # Some weights with extra keys with parameters
855*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
856*da0073e9SAndroid Build Coastguard Worker                      'l1.extra': extra_p}
857*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
858*da0073e9SAndroid Build Coastguard Worker        out, parameters, buffers = functional_call(module, parameters, x)
859*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(parameters, (weight, module.l1.bias, extra_p))
861*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(buffers, (module.buffer))
862*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias, extra_p))))
863*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
864*da0073e9SAndroid Build Coastguard Worker
865*da0073e9SAndroid Build Coastguard Worker        # Set None
866*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
867*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': None}
868*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
869*da0073e9SAndroid Build Coastguard Worker        out, parameters, buffers = functional_call(module, parameters, x)
870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, x * weight + module.buffer)
871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(parameters, (weight,))
872*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(buffers, (module.buffer))
873*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight,))))
874*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Worker
class TestStatelessDeprecation(TestCase):
    """Checks that the legacy stateless entry points emit warnings."""

    def test_private_stateless_warns(self):
        """Importing ``torch.nn.utils._stateless`` must record exactly one warning.

        The import runs in a fresh interpreter whose exit status is the
        number of warnings recorded during the import.
        """
        script = """
import torch
import warnings

with warnings.catch_warnings(record=True) as w:
    from torch.nn.utils import _stateless

exit(len(w))
"""
        # On Windows, opening the subprocess with the default CWD makes `import torch`
        # fail, so just set CWD to this script's directory
        script_dir = os.path.dirname(os.path.realpath(__file__))
        completed = subprocess.run(
            [sys.executable, '-W', 'always', '-c', script],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=script_dir,
        )
        # Exit status 0 means the import recorded no warning at all.
        self.assertTrue(completed.returncode != 0, "No warning was raised.")
        self.assertEqual(completed.returncode, 1)

    def test_stateless_functional_call_warns(self):
        """``stateless.functional_call`` must emit a ``FutureWarning``."""
        linear = torch.nn.Linear(1, 1)
        named_params = dict(linear.named_parameters())
        inp = torch.randn(3, 1)
        expected_hint = "Please use `torch.func.functional_call`"
        with self.assertWarnsRegex(FutureWarning, expected_hint):
            stateless.functional_call(linear, named_params, inp)
906*da0073e9SAndroid Build Coastguard Worker
class TestPythonOptimizeMode(TestCase):
    """Checks that the torch imports work when Python runs with ``-OO``."""

    def test_runs_with_optimize_flag(self):
        """Import torch and ``torch._functorch.deprecated`` under ``python -OO``."""
        script = "import torch; import torch._functorch.deprecated"
        # On Windows, opening the subprocess with the default CWD makes `import torch`
        # fail, so just set CWD to this script's directory
        script_dir = os.path.dirname(os.path.realpath(__file__))
        completed = subprocess.run(
            [sys.executable, "-OO", "-c", script],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=script_dir,
        )
        # A non-zero exit status means the import failed in the child process.
        self.assertFalse(completed.returncode, "Import failed while running python in optimized mode")
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker
# Process the @parametrize-decorated tests of TestStatelessFunctionalAPI
# (presumably generating one concrete test per subtest -- confirm against
# torch.testing._internal.common_utils).
instantiate_parametrized_tests(TestStatelessFunctionalAPI)

# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()
927