# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from torch.distributed.optim import _NamedOptimizer


def _run_model_training(model_optim_lists):
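    """Train each (model, optimizer-list) pair for two iterations, feeding the
    same random input to every model so that duplicate models evolve
    identically. Gradients are not zeroed between iterations; they accumulate
    the same way for every model being compared."""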
    for _ in range(2):
        x = torch.rand(5, 8)
        for model_optim_list in model_optim_lists:
            model = model_optim_list[0]
            optim_list = model_optim_list[1]
            y = model(x)
            y.sum().backward()
            for optim in optim_list:
                optim.step()


class TestDummyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
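        # Fixed seed so that duplicate instances start from identical weights.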
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))


class NamedOptimizerTest(unittest.TestCase):
    def _compare_state_dict_group(self, group, named_group, assert_equal=True):
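        """Compare one optimizer state (or param-group) entry against its
        named counterpart, key by key. With assert_equal=False, assert that
        the shared entries differ instead."""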
        for key, val in group.items():
            if key != "params":
                self.assertTrue(
                    key in named_group, f"{key} not in named optimizer state dict"
                )
                err_msg = (
                    f"{key} state not equal" if assert_equal else f"{key} state equal"
                )
                if isinstance(val, torch.Tensor):
                    fn = self.assertTrue if assert_equal else self.assertFalse
                    fn(torch.allclose(val, named_group[key]), err_msg)
                else:
                    fn = self.assertEqual if assert_equal else self.assertNotEqual
                    fn(val, named_group[key], err_msg)

    def _compare_param_groups(self, param_groups_1, param_groups_2):
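        """Compare two lists of optimizer param groups element-wise."""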
        self.assertTrue(isinstance(param_groups_1, list))
        self.assertTrue(isinstance(param_groups_2, list))
        for groups in zip(param_groups_1, param_groups_2):
            self._compare_param_group(groups[0], groups[1])

    def _compare_param_group(self, group_1, group_2):
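        """Assert that two param-group dicts agree: equal hyperparameters and
        allclose parameter tensors."""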
        self.assertTrue(isinstance(group_1, dict))
        self.assertTrue(isinstance(group_2, dict))
        for key, val in group_1.items():
            self.assertTrue(key in group_2)
            if key != "params":
                self.assertEqual(val, group_2[key])
            else:
                for tensors in zip(val, group_2[key]):
                    self.assertTrue(torch.allclose(tensors[0], tensors[1]))

    def test_state_dict(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            m.parameters(),
            lr=1e-2,
            momentum=0.9,
        )

        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        sd = optim.state_dict()
        named_sd = named_optim.state_dict()

        # Compare "state" in optim state dict
        self._compare_state_dict_group(
            sd["state"][0],
            named_sd["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][3],
            named_sd["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][4],
            named_sd["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][7],
            named_sd["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_state_dict_multi_param_group(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface when multiple param groups are specified."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim_1 = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        optim_2 = torch.optim.Adam(
            [
                {"params": m.net2.parameters()},
                {"params": m.net4.parameters(), "lr": 1e-5},
            ]
        )

        named_optim_1 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        named_optim_2 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.Adam,
            [
                {"params": m_dup.net2.parameters()},
                {"params": m_dup.net4.parameters(), "lr": 1e-5},
            ],
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)

        _run_model_training(
            [(m, [optim_1, optim_2]), (m_dup, [named_optim_1, named_optim_2])]
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
        sd_1 = optim_1.state_dict()
        sd_2 = optim_2.state_dict()
        named_sd_1 = named_optim_1.state_dict()
        named_sd_2 = named_optim_2.state_dict()

        # Compare "state" in optim state dict
        self._compare_state_dict_group(
            sd_1["state"][0],
            named_sd_1["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["state"][1],
            named_sd_2["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_1["state"][2],
            named_sd_1["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["state"][3],
            named_sd_2["state"]["net4.1.bias"],
            assert_equal=True,
        )

        # Compare "param_groups" in optim state dict
        self._compare_state_dict_group(
            sd_1["param_groups"][0],
            named_sd_1["param_groups"][0],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["param_groups"][1], named_sd_2["param_groups"][1], assert_equal=True
        )

    def test_load_state_dict(self):
        """Check that NamedOptimizer's load_state_dict works as expected."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        _run_model_training([(m, [named_optim_2])])
        state_dict_before_load = named_optim_2.state_dict()

        # Compare "state" in optim state dict
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_before_load["state"]["net1.0.weight"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net2.0.bias"],
            state_dict_before_load["state"]["net2.0.bias"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_before_load["state"]["net3.weight"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net4.1.bias"],
            state_dict_before_load["state"]["net4.1.bias"],
            assert_equal=False,
        )

        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()

        # Compare "state" in optim state dict
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_after_load["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net2.0.bias"],
            state_dict_after_load["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_after_load["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net4.1.bias"],
            state_dict_after_load["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_load_state_dict_conditional_training(self):
        """Check that NamedOptimizer's load_state_dict works when the loaded
        state dict comes from an optimizer that covered only a subset of the
        parameters."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        _run_model_training([(m, [named_optim_2])])
        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()

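        # named_optim_1 covered only net1 and net3, so only their entries
        # exist in the loaded state dict.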
        # Compare "state" in optim state dict
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_after_load["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_after_load["state"]["net3.weight"],
            assert_equal=True,
        )

    def test_load_state_dict_error(self):
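        """Check that load_state_dict raises if the wrapped optimizer has not
        been initialized by taking at least one step."""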
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        err_msg = (
            "Expects the optim to be initialized before load but found not initialized"
        )
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim_2.load_state_dict(state_dict_to_load)

    def test_add_param_group(self):
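        """Check that NamedOptimizer's add_param_group mirrors the plain
        optimizer's behavior."""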
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        optim.add_param_group({"params": m.net2.parameters(), "lr": 1e-5})
        named_optim.add_param_group({"params": m_dup.net2.parameters(), "lr": 1e-5})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        optim.add_param_group({"params": m.net4[1].weight, "lr": 1e-3})
        named_optim.add_param_group({"params": m_dup.net4[1].weight, "lr": 1e-3})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

    def test_add_param_group_error(self):
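        """Check that add_param_group rejects parameters that do not belong
        to the module."""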
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        err_msg = "some parameters are not in the module"
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim.add_param_group({"params": [torch.ones(8, 1)], "lr": 1e-5})

    def test_init_state(self):
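        """Check that init_state materializes gradients and optimizer state
        before any step has been taken."""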
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_sd = named_optim.state_dict()
        self.assertTrue(m.net1[0].weight.grad is None)
        self.assertTrue(len(named_sd["state"]) == 0)
        named_optim.init_state()
        named_sd = named_optim.state_dict()
        self.assertTrue(m.net1[0].weight.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net1.0.weight"])
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.weight"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.bias"]["momentum_buffer"]).item()
        )
        self.assertTrue(m.net3.bias.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net3.bias"])
        self.assertFalse(
            torch.all(named_sd["state"]["net3.bias"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net3.weight"]["momentum_buffer"]).item()
        )
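
if __name__ == "__main__":
    # Run this test module directly with the standard unittest runner.
    unittest.main()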