# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from torch.distributed.optim import _NamedOptimizer


def _run_model_training(model_optim_lists):
    """Run two training iterations for each (model, optimizer-list) pair.

    Args:
        model_optim_lists: iterable of ``(model, [optimizers])`` tuples.
            After each model's backward pass, every optimizer in its list
            takes a step.

    The same random input is fed to every model within an iteration so that
    models initialized with identical weights remain directly comparable.
    Gradients are intentionally not zeroed between iterations; both models
    under comparison accumulate them identically.
    """
    for _ in range(2):
        x = torch.rand(5, 8)
        for model_optim_list in model_optim_lists:
            model = model_optim_list[0]
            optim_list = model_optim_list[1]
            y = model(x)
            y.sum().backward()
            for optim in optim_list:
                optim.step()


class TestDummyModel(torch.nn.Module):
    """Small 4-stage MLP (8 -> 16 -> 32 -> 64 -> 8) with seeded init.

    ``torch.manual_seed(0)`` in the constructor makes every instance start
    with identical weights, so two copies can be trained in lockstep and
    compared parameter-by-parameter.
    """

    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))


class NamedOptimizerTest(unittest.TestCase):
    """Tests comparing ``_NamedOptimizer`` against the plain optimizer it wraps."""

    def _compare_state_dict_group(self, group, named_group, assert_equal=True):
        """Compare one state-dict entry from a plain optimizer with the
        corresponding entry from a named optimizer.

        The "params" key is skipped: the plain optimizer keys parameters by
        integer index while the named optimizer keys them by FQN, so the two
        can never match literally. With ``assert_equal=False`` the values are
        asserted to differ instead.
        """
        for key, val in group.items():
            if key != "params":
                self.assertTrue(
                    key in named_group, f"{key} not in named optimizer state dict"
                )
                err_msg = (
                    f"{key} state not equal" if assert_equal else f"{key} state equal"
                )
                if isinstance(val, torch.Tensor):
                    # Tensors need allclose; equality flips with assert_equal.
                    fn = self.assertTrue if assert_equal else self.assertFalse
                    fn(torch.allclose(val, named_group[key]), err_msg)
                else:
                    fn = self.assertEqual if assert_equal else self.assertNotEqual
                    fn(val, named_group[key], err_msg)

    def _compare_param_groups(self, param_groups_1, param_groups_2):
        """Assert two ``param_groups`` lists are element-wise equal."""
        self.assertTrue(isinstance(param_groups_1, list))
        self.assertTrue(isinstance(param_groups_2, list))
        for groups in zip(param_groups_1, param_groups_2):
            self._compare_param_group(groups[0], groups[1])

    def _compare_param_group(self, group_1, group_2):
        """Assert a single param group dict matches: hyperparameters by
        equality, the "params" tensors by ``allclose``."""
        self.assertTrue(isinstance(group_1, dict))
        self.assertTrue(isinstance(group_2, dict))
        for key, val in group_1.items():
            self.assertTrue(key in group_2)
            if key != "params":
                self.assertEqual(val, group_2[key])
            else:
                for tensors in zip(val, group_2[key]):
                    self.assertTrue(torch.allclose(tensors[0], tensors[1]))

    def test_state_dict(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            m.parameters(),
            lr=1e-2,
            momentum=0.9,
        )

        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        sd = optim.state_dict()
        named_sd = named_optim.state_dict()

        # Compare "state" in optim state dict: index N in the plain optimizer
        # corresponds to the Nth named parameter of the model.
        self._compare_state_dict_group(
            sd["state"][0],
            named_sd["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][3],
            named_sd["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][4],
            named_sd["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][7],
            named_sd["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_state_dict_multi_param_group(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface when multiple param groups are specified."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        # Two optimizers per model: SGD over net1/net3, Adam over net2/net4,
        # each with a per-group lr override on its second group.
        optim_1 = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        optim_2 = torch.optim.Adam(
            [
                {"params": m.net2.parameters()},
                {"params": m.net4.parameters(), "lr": 1e-5},
            ]
        )

        named_optim_1 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        named_optim_2 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.Adam,
            [
                {"params": m_dup.net2.parameters()},
                {"params": m_dup.net4.parameters(), "lr": 1e-5},
            ],
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)

        _run_model_training(
            [(m, [optim_1, optim_2]), (m_dup, [named_optim_1, named_optim_2])]
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
        sd_1 = optim_1.state_dict()
        sd_2 = optim_2.state_dict()
        named_sd_1 = named_optim_1.state_dict()
        named_sd_2 = named_optim_2.state_dict()

        # Compare "state" in optim state dict
        self._compare_state_dict_group(
            sd_1["state"][0],
            named_sd_1["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["state"][1],
            named_sd_2["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_1["state"][2],
            named_sd_1["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["state"][3],
            named_sd_2["state"]["net4.1.bias"],
            assert_equal=True,
        )

        # Compare "param_groups" in optim state dict
        self._compare_state_dict_group(
            sd_1["param_groups"][0],
            named_sd_1["param_groups"][0],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["param_groups"][1], named_sd_2["param_groups"][1], assert_equal=True
        )

    def test_load_state_dict(self):
        """Check that NamedOptimizer's load_state_dict works as expected."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        # Second optimizer uses a different momentum so its state diverges
        # from the first before loading.
        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        _run_model_training([(m, [named_optim_2])])
        state_dict_before_load = named_optim_2.state_dict()

        # Compare "state" in optim state dict: must differ before the load.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_before_load["state"]["net1.0.weight"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net2.0.bias"],
            state_dict_before_load["state"]["net2.0.bias"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_before_load["state"]["net3.weight"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net4.1.bias"],
            state_dict_before_load["state"]["net4.1.bias"],
            assert_equal=False,
        )

        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()

        # Compare "state" in optim state dict: must match after the load.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_after_load["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net2.0.bias"],
            state_dict_after_load["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_after_load["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net4.1.bias"],
            state_dict_after_load["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_load_state_dict_conditional_training(self):
        """Check that NamedOptimizer load_state_dict works under conditional training case."""
        m = TestDummyModel()
        # First optimizer trains only net1/net3; second trains everything.
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        _run_model_training([(m, [named_optim_2])])
        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()

        # Compare "state" in optim state dict: only the params covered by the
        # loaded (partial) state dict are checked.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_after_load["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_after_load["state"]["net3.weight"],
            assert_equal=True,
        )

    def test_load_state_dict_error(self):
        """Loading into a NamedOptimizer that has never stepped must raise."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        err_msg = (
            "Expects the optim to be initialized before load but found not initialized"
        )
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim_2.load_state_dict(state_dict_to_load)

    def test_add_param_group(self):
        """add_param_group on NamedOptimizer must mirror the plain optimizer."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        # Add a group via a parameter iterable.
        optim.add_param_group({"params": m.net2.parameters(), "lr": 1e-5})
        named_optim.add_param_group({"params": m_dup.net2.parameters(), "lr": 1e-5})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        # Add a group via a single bare parameter.
        optim.add_param_group({"params": m.net4[1].weight, "lr": 1e-3})
        named_optim.add_param_group({"params": m_dup.net4[1].weight, "lr": 1e-3})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

    def test_add_param_group_error(self):
        """Adding a param group with a tensor foreign to the module must raise."""
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        err_msg = "some parameters are not in the module"
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim.add_param_group({"params": [torch.ones(8, 1)], "lr": 1e-5})

    def test_init_state(self):
        """init_state must materialize zeroed grads and optimizer state
        (e.g. SGD momentum buffers) before any real training step."""
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_sd = named_optim.state_dict()
        # Before init_state: no grads and an empty state dict.
        self.assertTrue(m.net1[0].weight.grad is None)
        self.assertTrue(len(named_sd["state"]) == 0)
        named_optim.init_state()
        named_sd = named_optim.state_dict()
        # After init_state: grads exist and momentum buffers are all-zero
        # (torch.all(...) over zeros is False when converted to bool).
        self.assertTrue(m.net1[0].weight.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net1.0.weight"])
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.weight"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.bias"]["momentum_buffer"]).item()
        )
        self.assertTrue(m.net3.bias.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net3.bias"])
        self.assertFalse(
            torch.all(named_sd["state"]["net3.bias"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net3.weight"]["momentum_buffer"]).item()
        )