1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: optimizer", "module: LrScheduler" ] 2*da0073e9SAndroid Build Coastguard Workerimport copy 3*da0073e9SAndroid Build Coastguard Workerimport math 4*da0073e9SAndroid Build Coastguard Workerimport pickle 5*da0073e9SAndroid Build Coastguard Workerimport tempfile 6*da0073e9SAndroid Build Coastguard Workerimport types 7*da0073e9SAndroid Build Coastguard Workerimport warnings 8*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerimport torch 11*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 12*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import Parameter 13*da0073e9SAndroid Build Coastguard Workerfrom torch.optim import Adam, Rprop, SGD 14*da0073e9SAndroid Build Coastguard Workerfrom torch.optim.lr_scheduler import ( 15*da0073e9SAndroid Build Coastguard Worker ChainedScheduler, 16*da0073e9SAndroid Build Coastguard Worker ConstantLR, 17*da0073e9SAndroid Build Coastguard Worker CosineAnnealingLR, 18*da0073e9SAndroid Build Coastguard Worker CosineAnnealingWarmRestarts, 19*da0073e9SAndroid Build Coastguard Worker CyclicLR, 20*da0073e9SAndroid Build Coastguard Worker EPOCH_DEPRECATION_WARNING, 21*da0073e9SAndroid Build Coastguard Worker ExponentialLR, 22*da0073e9SAndroid Build Coastguard Worker LambdaLR, 23*da0073e9SAndroid Build Coastguard Worker LinearLR, 24*da0073e9SAndroid Build Coastguard Worker LRScheduler, 25*da0073e9SAndroid Build Coastguard Worker MultiplicativeLR, 26*da0073e9SAndroid Build Coastguard Worker MultiStepLR, 27*da0073e9SAndroid Build Coastguard Worker OneCycleLR, 28*da0073e9SAndroid Build Coastguard Worker PolynomialLR, 29*da0073e9SAndroid Build Coastguard Worker ReduceLROnPlateau, 30*da0073e9SAndroid Build Coastguard Worker SequentialLR, 31*da0073e9SAndroid Build Coastguard Worker StepLR, 32*da0073e9SAndroid Build Coastguard Worker) 33*da0073e9SAndroid 
from torch.optim.swa_utils import SWALR
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    load_tests,
    parametrize,
    skipIfTorchDynamo,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


class TestLRScheduler(TestCase):
    """Unit tests for the learning-rate schedulers in ``torch.optim.lr_scheduler``."""

    class SchedulerTestNet(torch.nn.Module):
        """Minimal model: two 1x1 ``Conv2d`` layers joined by a ReLU.

        ``setUp`` puts ``conv1`` and ``conv2`` into separate optimizer parameter
        groups with different learning rates.
        """

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 1, 1)
            self.conv2 = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            # conv1 -> ReLU -> conv2
            return self.conv2(F.relu(self.conv1(x)))

    class LambdaLRTestObject:
        """Callable object that multiplies ``epoch`` by a stored ``value``.

        Instances compare equal when their attributes (``__dict__``) are equal,
        which a plain ``lambda`` (compared by identity) does not offer.
        NOTE(review): its users are not visible in this part of the file;
        presumably it serves as an ``lr_lambda`` in state/serialization tests.
        Confirm.
        """

        def __init__(self, value):
            # Multiplier applied to ``epoch`` by ``__call__``.
            self.value = value

        def __call__(self, epoch):
            return self.value * epoch

        def __eq__(self, other):
            # Equal only to another instance of the same class with identical
            # attributes; any other object compares unequal.
            if isinstance(other, self.__class__):
                return self.__dict__ == other.__dict__
            else:
                return False

    # NOTE(review): presumably tells the common_utils TestCase that assertEqual
    # should also compare dtypes exactly. Confirm.
    exact_dtype = True

    def setUp(self):
        """Create a fresh model and an SGD optimizer with two parameter groups.

        Group 0 (``conv1``) uses the default ``lr=0.05``. Group 1 (``conv2``)
        overrides it with ``lr=0.5``, i.e. 10x group 0.
        """
        super().setUp()
        self.net = self.SchedulerTestNet()
        self.opt = SGD(
            [
                {"params": self.net.conv1.parameters()},
                {"params": self.net.conv2.parameters(), "lr": 0.5},
            ],
            lr=0.05,
        )

    def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1):
        """Assert that ``w`` holds exactly ``num_warnings`` epoch-deprecation warnings.

        Callers record warnings with ``warnings.catch_warnings(record=True)``
        around ``scheduler.step(epoch)`` calls that pass a non-``None`` ``epoch``.
        Passing ``epoch`` is deprecated. This helper checks that every recorded
        warning is that deprecation message and nothing else. It will need to
        be removed or updated once schedulers no longer accept the parameter at
        all.

        Args:
            w: the list of warnings recorded by ``catch_warnings(record=True)``.
            num_warnings: exact number of warnings expected in ``w``.
        """
        self.assertEqual(len(w), num_warnings)
        for warning in w:
            # Each warning must carry exactly one argument: the deprecation text.
            self.assertEqual(len(warning.message.args), 1)
            self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING)

    def test_error_when_getlr_has_epoch(self):
        """A scheduler whose ``get_lr`` requires an extra argument fails to construct.

        The local ``MultiStepLR`` subclass declares ``get_lr(self, step)``, and
        constructing it must raise ``TypeError``. NOTE(review): presumably this
        is because ``LRScheduler.__init__`` calls ``get_lr()`` without arguments
        during its initial step. Confirm against the base class.
        """

        class MultiStepLR(torch.optim.lr_scheduler.LRScheduler):
            def __init__(self, optimizer, gamma, milestones, last_epoch=-1):
                self.init_lr = [group["lr"] for group in optimizer.param_groups]
                self.gamma = gamma
                self.milestones = milestones
                super().__init__(optimizer, last_epoch)

            def get_lr(self, step):
                global_step = self.last_epoch
                # 1-based index of the last milestone in self.milestones with
                # global_step >= m, or 0 when none has been reached yet.
                gamma_power = (
                    [0]
                    + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m]
                )[-1]
                return [
                    init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr
                ]

        optimizer = SGD([torch.rand(1)], lr=1)

        with self.assertRaises(TypeError):
116*da0073e9SAndroid Build Coastguard Worker scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20]) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo( 119*da0073e9SAndroid Build Coastguard Worker "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames" 120*da0073e9SAndroid Build Coastguard Worker ) 121*da0073e9SAndroid Build Coastguard Worker def test_no_cyclic_references(self): 122*da0073e9SAndroid Build Coastguard Worker import gc 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker param = Parameter(torch.empty(10)) 125*da0073e9SAndroid Build Coastguard Worker optim = SGD([param], lr=0.5) 126*da0073e9SAndroid Build Coastguard Worker scheduler = LambdaLR(optim, lambda epoch: 1.0) 127*da0073e9SAndroid Build Coastguard Worker del scheduler 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 130*da0073e9SAndroid Build Coastguard Worker len(gc.get_referrers(optim)) == 0, 131*da0073e9SAndroid Build Coastguard Worker "Optimizer should contain no cyclic references", 132*da0073e9SAndroid Build Coastguard Worker ) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker gc.collect() 135*da0073e9SAndroid Build Coastguard Worker del optim 136*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 137*da0073e9SAndroid Build Coastguard Worker gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__" 138*da0073e9SAndroid Build Coastguard Worker ) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo( 141*da0073e9SAndroid Build Coastguard Worker "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames" 142*da0073e9SAndroid Build Coastguard Worker ) 143*da0073e9SAndroid Build Coastguard Worker def test_no_cyclic_references_in_step(self): 
        import gc
        import weakref

        def run():
            # Build an optimizer/scheduler pair, take one optimizer step and one
            # scheduler step, then hand back only a weak reference to the
            # scheduler so no strong reference to it outlives this function.
            param = torch.empty(10, requires_grad=True)
            optim = SGD(params=[param], lr=0.5)
            scheduler = LambdaLR(optim, lambda epoch: 1.0)
            param.sum().backward()
            optim.step()
            scheduler.step()

            return weakref.ref(scheduler)

        # With the garbage collector disabled, objects are freed only by
        # reference counting. If the scheduler were part of a reference cycle,
        # it would stay alive after run() returns and the weak reference would
        # still resolve. With gc enabled, the collector could free such a cycle
        # and hide the problem.
160*da0073e9SAndroid Build Coastguard Worker gc.disable() 161*da0073e9SAndroid Build Coastguard Worker ref = run() 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker assert ref() is None 164*da0073e9SAndroid Build Coastguard Worker gc.enable() # restore 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker def test_old_pattern_warning(self): 167*da0073e9SAndroid Build Coastguard Worker epochs = 35 168*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 169*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 170*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 171*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker def old_pattern(): 174*da0073e9SAndroid Build Coastguard Worker for _ in range(epochs): 175*da0073e9SAndroid Build Coastguard Worker scheduler.step() 176*da0073e9SAndroid Build Coastguard Worker self.opt.step() 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern) 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker def test_old_pattern_warning_with_arg(self): 181*da0073e9SAndroid Build Coastguard Worker epochs = 35 182*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 183*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 184*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 185*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 186*da0073e9SAndroid Build Coastguard Worker 
187*da0073e9SAndroid Build Coastguard Worker def old_pattern2(): 188*da0073e9SAndroid Build Coastguard Worker for _ in range(epochs): 189*da0073e9SAndroid Build Coastguard Worker scheduler.step() 190*da0073e9SAndroid Build Coastguard Worker self.opt.step() 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker def test_old_pattern_warning_resuming(self): 195*da0073e9SAndroid Build Coastguard Worker epochs = 35 196*da0073e9SAndroid Build Coastguard Worker for i, group in enumerate(self.opt.param_groups): 197*da0073e9SAndroid Build Coastguard Worker group["initial_lr"] = 0.01 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 200*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 201*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) 202*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker def old_pattern(): 205*da0073e9SAndroid Build Coastguard Worker for _ in range(epochs): 206*da0073e9SAndroid Build Coastguard Worker scheduler.step() 207*da0073e9SAndroid Build Coastguard Worker self.opt.step() 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern) 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker def test_old_pattern_warning_resuming_with_arg(self): 212*da0073e9SAndroid Build Coastguard Worker epochs = 35 213*da0073e9SAndroid Build Coastguard Worker for i, group in 
enumerate(self.opt.param_groups): 214*da0073e9SAndroid Build Coastguard Worker group["initial_lr"] = 0.01 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 217*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 218*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) 219*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker def old_pattern2(): 222*da0073e9SAndroid Build Coastguard Worker for _ in range(epochs): 223*da0073e9SAndroid Build Coastguard Worker scheduler.step() 224*da0073e9SAndroid Build Coastguard Worker self.opt.step() 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker def test_old_pattern_warning_with_overridden_optim_step(self): 229*da0073e9SAndroid Build Coastguard Worker epochs = 35 230*da0073e9SAndroid Build Coastguard Worker for i, group in enumerate(self.opt.param_groups): 231*da0073e9SAndroid Build Coastguard Worker group["initial_lr"] = 0.01 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 234*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 235*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) 236*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker # emulate 
use-case with optimizer.step overridden 239*da0073e9SAndroid Build Coastguard Worker import types 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker old_step = self.opt.step 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker def new_step(o, *args, **kwargs): 244*da0073e9SAndroid Build Coastguard Worker retval = old_step(*args, **kwargs) 245*da0073e9SAndroid Build Coastguard Worker return retval 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker self.opt.step = types.MethodType(new_step, self.opt) 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker def old_pattern2(): 250*da0073e9SAndroid Build Coastguard Worker for _ in range(epochs): 251*da0073e9SAndroid Build Coastguard Worker scheduler.step() 252*da0073e9SAndroid Build Coastguard Worker self.opt.step() 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker def test_new_pattern_no_warning(self): 257*da0073e9SAndroid Build Coastguard Worker epochs = 35 258*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 259*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 260*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 261*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 264*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 265*da0073e9SAndroid Build Coastguard Worker for _ in 
range(epochs): 266*da0073e9SAndroid Build Coastguard Worker self.opt.step() 267*da0073e9SAndroid Build Coastguard Worker scheduler.step() 268*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker def test_new_pattern_no_warning_with_arg(self): 271*da0073e9SAndroid Build Coastguard Worker epochs = 35 272*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 273*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 274*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 275*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 278*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 279*da0073e9SAndroid Build Coastguard Worker for _ in range(epochs): 280*da0073e9SAndroid Build Coastguard Worker self.opt.step() 281*da0073e9SAndroid Build Coastguard Worker scheduler.step() 282*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning should be raised") 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker def test_new_pattern_no_warning_with_overridden_optim_step(self): 285*da0073e9SAndroid Build Coastguard Worker epochs = 35 286*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as ws: 287*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # allow any warning to be raised 288*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 289*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(ws) == 0, "No warning 
should be raised") 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker # emulate use-case with optimizer.step overridden 292*da0073e9SAndroid Build Coastguard Worker import types 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker old_step = self.opt.step 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker def new_step(o, *args, **kwargs): 297*da0073e9SAndroid Build Coastguard Worker retval = old_step(*args, **kwargs) 298*da0073e9SAndroid Build Coastguard Worker return retval 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker self.opt.step = types.MethodType(new_step, self.opt) 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker def new_pattern(): 303*da0073e9SAndroid Build Coastguard Worker for e in range(epochs): 304*da0073e9SAndroid Build Coastguard Worker self.opt.step() 305*da0073e9SAndroid Build Coastguard Worker scheduler.step() 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker self.assertWarnsRegex( 308*da0073e9SAndroid Build Coastguard Worker UserWarning, r"`optimizer.step\(\)` has been overridden", new_pattern 309*da0073e9SAndroid Build Coastguard Worker ) 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker def _test_lr_is_constant_for_constant_epoch(self, scheduler): 312*da0073e9SAndroid Build Coastguard Worker l = [] 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 315*da0073e9SAndroid Build Coastguard Worker scheduler.optimizer.step() 316*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 317*da0073e9SAndroid Build Coastguard Worker scheduler.step(2) 318*da0073e9SAndroid Build Coastguard Worker self._check_warning_is_epoch_deprecation_warning(w) 319*da0073e9SAndroid Build Coastguard Worker 
320*da0073e9SAndroid Build Coastguard Worker l.append(self.opt.param_groups[0]["lr"]) 321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(min(l), max(l)) 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker def test_step_lr_is_constant_for_constant_epoch(self): 324*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, 2) 325*da0073e9SAndroid Build Coastguard Worker self._test_lr_is_constant_for_constant_epoch(scheduler) 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker def test_exponential_lr_is_constant_for_constant_epoch(self): 328*da0073e9SAndroid Build Coastguard Worker scheduler = ExponentialLR(self.opt, gamma=0.9) 329*da0073e9SAndroid Build Coastguard Worker self._test_lr_is_constant_for_constant_epoch(scheduler) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker def test_constantlr_is_constant_for_constant_epoch(self): 332*da0073e9SAndroid Build Coastguard Worker scheduler = ConstantLR(self.opt) 333*da0073e9SAndroid Build Coastguard Worker self._test_lr_is_constant_for_constant_epoch(scheduler) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker def test_linear_linearlr_is_constant_for_constant_epoch(self): 336*da0073e9SAndroid Build Coastguard Worker scheduler = LinearLR(self.opt) 337*da0073e9SAndroid Build Coastguard Worker self._test_lr_is_constant_for_constant_epoch(scheduler) 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker def test_polynomial_lr_is_constant_for_constant_epoch(self): 340*da0073e9SAndroid Build Coastguard Worker scheduler = PolynomialLR(self.opt, power=0.9) 341*da0073e9SAndroid Build Coastguard Worker self._test_lr_is_constant_for_constant_epoch(scheduler) 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker def test_step_lr(self): 344*da0073e9SAndroid Build Coastguard Worker # lr = 
0.05 if epoch < 3 345*da0073e9SAndroid Build Coastguard Worker # lr = 0.005 if 30 <= epoch < 6 346*da0073e9SAndroid Build Coastguard Worker # lr = 0.0005 if epoch >= 9 347*da0073e9SAndroid Build Coastguard Worker epochs = 10 348*da0073e9SAndroid Build Coastguard Worker single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 349*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 350*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 351*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker def test_get_last_lr_step_lr(self): 354*da0073e9SAndroid Build Coastguard Worker from torch.nn import Parameter 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker epochs = 10 357*da0073e9SAndroid Build Coastguard Worker optimizer = SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1) 358*da0073e9SAndroid Build Coastguard Worker targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]] 359*da0073e9SAndroid Build Coastguard Worker scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1) 360*da0073e9SAndroid Build Coastguard Worker self._test_get_last_lr(scheduler, targets, epochs) 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker def test_get_last_lr_multi_step_lr(self): 363*da0073e9SAndroid Build Coastguard Worker # lr = 0.05 if epoch < 2 364*da0073e9SAndroid Build Coastguard Worker # lr = 0.005 if 2 <= epoch < 5 365*da0073e9SAndroid Build Coastguard Worker # lr = 0.0005 if 5 <= epoch < 9 366*da0073e9SAndroid Build Coastguard Worker # lr = 0.00005 if 9 <= epoch 367*da0073e9SAndroid Build Coastguard Worker epochs = 10 368*da0073e9SAndroid Build Coastguard Worker single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1 369*da0073e9SAndroid 
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_multi_step_lr(self):
        # lr = 0.05 if epoch < 2
        # lr = 0.005 if 2 <= epoch < 5
        # lr = 0.0005 if 5 <= epoch < 9
        # lr = 0.00005 if epoch >= 9
        epochs = 10
        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test(scheduler, targets, epochs)

    def test_multi_step_lr_with_epoch(self):
        # lr = 0.05 if epoch < 2
        # lr = 0.005 if 2 <= epoch < 5
        # lr = 0.0005 if 5 <= epoch < 9
        # lr = 0.00005 if epoch >= 9
        epochs = 10
        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test_with_epoch(scheduler, targets, epochs)

    def test_get_last_lr_constantlr(self):
        # lr = 0.025 if epoch < 5
        # lr = 0.05 if 5 <= epoch
        epochs = 10
        single_targets = [0.025] * 5 + [0.05] * 5
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_get_last_lr_linearlr(self):
        # Factor goes linearly from 1/4 to 3/5 over 4 iterations (base lr 0.05):
        # lr = 0.0125 if epoch == 0
        # lr = 0.016875 if epoch == 1
        # lr = 0.02125 if epoch == 2
        # lr = 0.025625 if epoch == 3
        # lr = 0.03 if 4 <= epoch
        epochs = 10
        start_factor = 1.0 / 4
        end_factor = 3.0 / 5
        iters = 4
        interpolation = [
            start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
        ]
        single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (
            epochs - iters
        )
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = LinearLR(
            self.opt,
            start_factor=start_factor,
            end_factor=end_factor,
            total_iters=iters,
        )
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_constantlr(self):
        # lr = 0.025 if epoch < 5
        # lr = 0.05 if 5 <= epoch
        epochs = 10
        single_targets = [0.025] * 5 + [0.05] * 5
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
        self._test(scheduler, targets, epochs)

    def test_linearlr(self):
        # lr = 0.025 if epoch == 0
        # lr = 0.03125 if epoch == 1
        # lr = 0.0375 if epoch == 2
        # lr = 0.04375 if epoch == 3
        # lr = 0.05 if 4 <= epoch
        epochs = 10
        start_factor = 1.0 / 2
        iters = 4
447*da0073e9SAndroid Build Coastguard Worker interpolation = [ 448*da0073e9SAndroid Build Coastguard Worker start_factor + i * (1 - start_factor) / iters for i in range(iters) 449*da0073e9SAndroid Build Coastguard Worker ] 450*da0073e9SAndroid Build Coastguard Worker single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) 451*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 452*da0073e9SAndroid Build Coastguard Worker scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 453*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker def test_linearlr_start_factor_limits1(self): 456*da0073e9SAndroid Build Coastguard Worker start_factor = 0.0 457*da0073e9SAndroid Build Coastguard Worker iters = 4 458*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 459*da0073e9SAndroid Build Coastguard Worker LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker def test_linearlr_start_factor_limits2(self): 462*da0073e9SAndroid Build Coastguard Worker start_factor = 1.1 463*da0073e9SAndroid Build Coastguard Worker iters = 4 464*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 465*da0073e9SAndroid Build Coastguard Worker LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker def test_constantlr_with_epoch(self): 468*da0073e9SAndroid Build Coastguard Worker # lr = 0.025 if epoch < 5 469*da0073e9SAndroid Build Coastguard Worker # lr = 0.005 if 5 <= epoch 470*da0073e9SAndroid Build Coastguard Worker epochs = 10 471*da0073e9SAndroid Build Coastguard Worker single_targets = [0.025] * 5 + [0.05] * 5 472*da0073e9SAndroid Build 
Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 473*da0073e9SAndroid Build Coastguard Worker scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) 474*da0073e9SAndroid Build Coastguard Worker self._test_with_epoch(scheduler, targets, epochs) 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker def test_linearlr_with_epoch(self): 477*da0073e9SAndroid Build Coastguard Worker # lr = 0.025 if epoch == 0 478*da0073e9SAndroid Build Coastguard Worker # lr = 0.03125 if epoch == 1 479*da0073e9SAndroid Build Coastguard Worker # lr = 0.0375 if epoch == 2 480*da0073e9SAndroid Build Coastguard Worker # lr = 0.04375 if epoch == 3 481*da0073e9SAndroid Build Coastguard Worker # lr = 0.005 if 4 <= epoch 482*da0073e9SAndroid Build Coastguard Worker epochs = 10 483*da0073e9SAndroid Build Coastguard Worker start_factor = 1.0 / 2 484*da0073e9SAndroid Build Coastguard Worker end_factor = 1.0 485*da0073e9SAndroid Build Coastguard Worker iters = 4 486*da0073e9SAndroid Build Coastguard Worker interpolation = [ 487*da0073e9SAndroid Build Coastguard Worker start_factor + i * (end_factor - start_factor) / iters for i in range(iters) 488*da0073e9SAndroid Build Coastguard Worker ] 489*da0073e9SAndroid Build Coastguard Worker single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) 490*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 491*da0073e9SAndroid Build Coastguard Worker scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 492*da0073e9SAndroid Build Coastguard Worker self._test_with_epoch(scheduler, targets, epochs) 493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Worker def test_exp_lr(self): 495*da0073e9SAndroid Build Coastguard Worker epochs = 10 496*da0073e9SAndroid Build Coastguard Worker single_targets = [0.05 * (0.9**x) for x in range(epochs)] 497*da0073e9SAndroid 
Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 498*da0073e9SAndroid Build Coastguard Worker scheduler = ExponentialLR(self.opt, gamma=0.9) 499*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker def test_poly_lr(self): 502*da0073e9SAndroid Build Coastguard Worker epochs = 10 503*da0073e9SAndroid Build Coastguard Worker power = 0.9 504*da0073e9SAndroid Build Coastguard Worker total_iters = 5 505*da0073e9SAndroid Build Coastguard Worker single_targets = [ 506*da0073e9SAndroid Build Coastguard Worker (1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters) 507*da0073e9SAndroid Build Coastguard Worker ] + [0.0] * (epochs - total_iters) 508*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 509*da0073e9SAndroid Build Coastguard Worker scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters) 510*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker def test_cos_anneal_lr(self): 513*da0073e9SAndroid Build Coastguard Worker epochs = 10 514*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 515*da0073e9SAndroid Build Coastguard Worker single_targets = [ 516*da0073e9SAndroid Build Coastguard Worker eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 517*da0073e9SAndroid Build Coastguard Worker for x in range(epochs) 518*da0073e9SAndroid Build Coastguard Worker ] 519*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 520*da0073e9SAndroid Build Coastguard Worker scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 521*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 522*da0073e9SAndroid Build 
Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker def test_closed_form_step_lr(self): 524*da0073e9SAndroid Build Coastguard Worker scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 525*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 526*da0073e9SAndroid Build Coastguard Worker self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker def test_closed_form_linearlr(self): 529*da0073e9SAndroid Build Coastguard Worker scheduler = LinearLR( 530*da0073e9SAndroid Build Coastguard Worker self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4 531*da0073e9SAndroid Build Coastguard Worker ) 532*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler = LinearLR( 533*da0073e9SAndroid Build Coastguard Worker self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4 534*da0073e9SAndroid Build Coastguard Worker ) 535*da0073e9SAndroid Build Coastguard Worker self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker def test_closed_form_constantlr(self): 538*da0073e9SAndroid Build Coastguard Worker scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) 539*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) 540*da0073e9SAndroid Build Coastguard Worker self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Worker def test_closed_form_multi_step_lr(self): 543*da0073e9SAndroid Build Coastguard Worker scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 544*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 545*da0073e9SAndroid Build Coastguard 
Worker self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker def test_closed_form_exp_lr(self): 548*da0073e9SAndroid Build Coastguard Worker scheduler = ExponentialLR(self.opt, gamma=0.9) 549*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9) 550*da0073e9SAndroid Build Coastguard Worker self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker def test_closed_form_poly_lr(self): 553*da0073e9SAndroid Build Coastguard Worker scheduler = PolynomialLR(self.opt, power=0.9) 554*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler = PolynomialLR(self.opt, power=0.9) 555*da0073e9SAndroid Build Coastguard Worker self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 556*da0073e9SAndroid Build Coastguard Worker 557*da0073e9SAndroid Build Coastguard Worker def test_closed_form_cos_anneal_lr(self): 558*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 559*da0073e9SAndroid Build Coastguard Worker epochs = 20 560*da0073e9SAndroid Build Coastguard Worker T_max = 5 561*da0073e9SAndroid Build Coastguard Worker scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) 562*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler = CosineAnnealingLR( 563*da0073e9SAndroid Build Coastguard Worker self.opt, T_max=T_max, eta_min=eta_min 564*da0073e9SAndroid Build Coastguard Worker ) 565*da0073e9SAndroid Build Coastguard Worker self._test_against_closed_form(scheduler, closed_form_scheduler, epochs) 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker def test_cos_anneal_lr_continue(self): 568*da0073e9SAndroid Build Coastguard Worker eta_min = 0.1 569*da0073e9SAndroid Build Coastguard Worker T_max = 5 570*da0073e9SAndroid Build Coastguard Worker scheduler 
= CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) 571*da0073e9SAndroid Build Coastguard Worker self.opt.step() 572*da0073e9SAndroid Build Coastguard Worker scheduler.step() 573*da0073e9SAndroid Build Coastguard Worker original_lrs = scheduler._last_lr 574*da0073e9SAndroid Build Coastguard Worker new_scheduler = CosineAnnealingLR( 575*da0073e9SAndroid Build Coastguard Worker self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0 576*da0073e9SAndroid Build Coastguard Worker ) 577*da0073e9SAndroid Build Coastguard Worker new_lrs = new_scheduler._last_lr 578*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5) 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau1(self): 581*da0073e9SAndroid Build Coastguard Worker epochs = 10 582*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 583*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 584*da0073e9SAndroid Build Coastguard Worker targets = [[0.5] * 20] 585*da0073e9SAndroid Build Coastguard Worker metrics = [10 - i * 0.0167 for i in range(20)] 586*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 587*da0073e9SAndroid Build Coastguard Worker self.opt, 588*da0073e9SAndroid Build Coastguard Worker threshold_mode="abs", 589*da0073e9SAndroid Build Coastguard Worker mode="min", 590*da0073e9SAndroid Build Coastguard Worker threshold=0.01, 591*da0073e9SAndroid Build Coastguard Worker patience=5, 592*da0073e9SAndroid Build Coastguard Worker cooldown=5, 593*da0073e9SAndroid Build Coastguard Worker ) 594*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau2(self): 597*da0073e9SAndroid Build Coastguard Worker epochs = 22 598*da0073e9SAndroid Build Coastguard 
Worker for param_group in self.opt.param_groups: 599*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 600*da0073e9SAndroid Build Coastguard Worker targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2] 601*da0073e9SAndroid Build Coastguard Worker metrics = [10 - i * 0.0165 for i in range(22)] 602*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 603*da0073e9SAndroid Build Coastguard Worker self.opt, 604*da0073e9SAndroid Build Coastguard Worker patience=5, 605*da0073e9SAndroid Build Coastguard Worker cooldown=0, 606*da0073e9SAndroid Build Coastguard Worker threshold_mode="abs", 607*da0073e9SAndroid Build Coastguard Worker mode="min", 608*da0073e9SAndroid Build Coastguard Worker threshold=0.1, 609*da0073e9SAndroid Build Coastguard Worker ) 610*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau3(self): 613*da0073e9SAndroid Build Coastguard Worker epochs = 22 614*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 615*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 616*da0073e9SAndroid Build Coastguard Worker targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4] 617*da0073e9SAndroid Build Coastguard Worker metrics = [-0.8] * 2 + [-0.234] * 20 618*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 619*da0073e9SAndroid Build Coastguard Worker self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs" 620*da0073e9SAndroid Build Coastguard Worker ) 621*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau4(self): 624*da0073e9SAndroid Build Coastguard Worker epochs = 20 625*da0073e9SAndroid Build 
Coastguard Worker for param_group in self.opt.param_groups: 626*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 627*da0073e9SAndroid Build Coastguard Worker targets = [[0.5] * 20] 628*da0073e9SAndroid Build Coastguard Worker metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25 629*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 630*da0073e9SAndroid Build Coastguard Worker self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1 631*da0073e9SAndroid Build Coastguard Worker ) 632*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 633*da0073e9SAndroid Build Coastguard Worker 634*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau5(self): 635*da0073e9SAndroid Build Coastguard Worker epochs = 20 636*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 637*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 638*da0073e9SAndroid Build Coastguard Worker targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] 639*da0073e9SAndroid Build Coastguard Worker metrics = [1.5 * (1.005**i) for i in range(20)] 640*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 641*da0073e9SAndroid Build Coastguard Worker self.opt, 642*da0073e9SAndroid Build Coastguard Worker mode="max", 643*da0073e9SAndroid Build Coastguard Worker threshold_mode="rel", 644*da0073e9SAndroid Build Coastguard Worker threshold=0.1, 645*da0073e9SAndroid Build Coastguard Worker patience=5, 646*da0073e9SAndroid Build Coastguard Worker cooldown=5, 647*da0073e9SAndroid Build Coastguard Worker ) 648*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau6(self): 651*da0073e9SAndroid Build Coastguard Worker epochs = 20 652*da0073e9SAndroid Build 
Coastguard Worker for param_group in self.opt.param_groups: 653*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 654*da0073e9SAndroid Build Coastguard Worker targets = [[0.5] * 20] 655*da0073e9SAndroid Build Coastguard Worker metrics = [1.5 * (0.85**i) for i in range(20)] 656*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 657*da0073e9SAndroid Build Coastguard Worker self.opt, mode="min", threshold_mode="rel", threshold=0.1 658*da0073e9SAndroid Build Coastguard Worker ) 659*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 660*da0073e9SAndroid Build Coastguard Worker 661*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau7(self): 662*da0073e9SAndroid Build Coastguard Worker epochs = 20 663*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 664*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 665*da0073e9SAndroid Build Coastguard Worker targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] 666*da0073e9SAndroid Build Coastguard Worker metrics = [1] * 7 + [0.6] + [0.5] * 12 667*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 668*da0073e9SAndroid Build Coastguard Worker self.opt, 669*da0073e9SAndroid Build Coastguard Worker mode="min", 670*da0073e9SAndroid Build Coastguard Worker threshold_mode="rel", 671*da0073e9SAndroid Build Coastguard Worker threshold=0.1, 672*da0073e9SAndroid Build Coastguard Worker patience=5, 673*da0073e9SAndroid Build Coastguard Worker cooldown=5, 674*da0073e9SAndroid Build Coastguard Worker ) 675*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau8(self): 678*da0073e9SAndroid Build Coastguard Worker epochs = 20 679*da0073e9SAndroid Build Coastguard Worker for param_group in 
self.opt.param_groups: 680*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 681*da0073e9SAndroid Build Coastguard Worker targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14] 682*da0073e9SAndroid Build Coastguard Worker metrics = [1.5 * (1.005**i) for i in range(20)] 683*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 684*da0073e9SAndroid Build Coastguard Worker self.opt, 685*da0073e9SAndroid Build Coastguard Worker mode="max", 686*da0073e9SAndroid Build Coastguard Worker threshold_mode="rel", 687*da0073e9SAndroid Build Coastguard Worker min_lr=[0.4, 0.3], 688*da0073e9SAndroid Build Coastguard Worker threshold=0.1, 689*da0073e9SAndroid Build Coastguard Worker patience=5, 690*da0073e9SAndroid Build Coastguard Worker cooldown=5, 691*da0073e9SAndroid Build Coastguard Worker ) 692*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau_get_last_lr_before_step(self): 695*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 696*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 697*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau( 698*da0073e9SAndroid Build Coastguard Worker self.opt, 699*da0073e9SAndroid Build Coastguard Worker ) 700*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 701*da0073e9SAndroid Build Coastguard Worker scheduler.get_last_lr(), [0.5 for param_group in self.opt.param_groups] 702*da0073e9SAndroid Build Coastguard Worker ) 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker def test_sequentiallr1(self): 705*da0073e9SAndroid Build Coastguard Worker epochs = 19 706*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 707*da0073e9SAndroid Build Coastguard Worker targets = [ 708*da0073e9SAndroid Build Coastguard 
Worker [0.05, 0.04, 0.032] 709*da0073e9SAndroid Build Coastguard Worker + [0.05 for x in range(4)] 710*da0073e9SAndroid Build Coastguard Worker + [0.05 * 0.1 for x in range(4)] 711*da0073e9SAndroid Build Coastguard Worker + [0.05 * 0.01 for x in range(4)] 712*da0073e9SAndroid Build Coastguard Worker + [0.05 * 0.001 for x in range(4)] 713*da0073e9SAndroid Build Coastguard Worker ] 714*da0073e9SAndroid Build Coastguard Worker milestones = [3] 715*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ExponentialLR(self.opt, gamma=0.8) 716*da0073e9SAndroid Build Coastguard Worker schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4) 717*da0073e9SAndroid Build Coastguard Worker scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) 718*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker def test_sequentiallr2(self): 721*da0073e9SAndroid Build Coastguard Worker epochs = 13 722*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 723*da0073e9SAndroid Build Coastguard Worker targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9**x for x in range(10)]] 724*da0073e9SAndroid Build Coastguard Worker milestones = [3] 725*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) 726*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ExponentialLR(self.opt, gamma=0.9) 727*da0073e9SAndroid Build Coastguard Worker scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) 728*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Worker def test_sequentiallr3(self): 731*da0073e9SAndroid Build Coastguard Worker epochs = 12 732*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 3 733*da0073e9SAndroid Build Coastguard Worker targets = [ 
734*da0073e9SAndroid Build Coastguard Worker [0.005, 0.005, 0.005] 735*da0073e9SAndroid Build Coastguard Worker + [0.05, 0.04, 0.032] 736*da0073e9SAndroid Build Coastguard Worker + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005] 737*da0073e9SAndroid Build Coastguard Worker ] 738*da0073e9SAndroid Build Coastguard Worker milestones = [3, 6] 739*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) 740*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ExponentialLR(self.opt, gamma=0.8) 741*da0073e9SAndroid Build Coastguard Worker schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2) 742*da0073e9SAndroid Build Coastguard Worker scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) 743*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 744*da0073e9SAndroid Build Coastguard Worker 745*da0073e9SAndroid Build Coastguard Worker def test_sequentiallr4(self): 746*da0073e9SAndroid Build Coastguard Worker optimizer = SGD([torch.tensor(0.5)], lr=0.1) 747*da0073e9SAndroid Build Coastguard Worker prev_lr = optimizer.param_groups[0]["lr"] 748*da0073e9SAndroid Build Coastguard Worker 749*da0073e9SAndroid Build Coastguard Worker schedulers = [ 750*da0073e9SAndroid Build Coastguard Worker torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1), 751*da0073e9SAndroid Build Coastguard Worker torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1), 752*da0073e9SAndroid Build Coastguard Worker ] 753*da0073e9SAndroid Build Coastguard Worker scheduler = torch.optim.lr_scheduler.SequentialLR( 754*da0073e9SAndroid Build Coastguard Worker optimizer, schedulers, milestones=[10] 755*da0073e9SAndroid Build Coastguard Worker ) 756*da0073e9SAndroid Build Coastguard Worker 757*da0073e9SAndroid Build Coastguard Worker new_lr = optimizer.param_groups[0]["lr"] 758*da0073e9SAndroid Build Coastguard Worker 759*da0073e9SAndroid Build Coastguard Worker # Ensure that multiple schedulers does 
not affect the initial learning rate 760*da0073e9SAndroid Build Coastguard Worker self.assertEqual(prev_lr, new_lr) 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker def test_get_last_lr_sequentiallr(self): 763*da0073e9SAndroid Build Coastguard Worker epochs = 12 764*da0073e9SAndroid Build Coastguard Worker milestones = [3, 6] 765*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 3 766*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) 767*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ExponentialLR(self.opt, gamma=0.8) 768*da0073e9SAndroid Build Coastguard Worker schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2) 769*da0073e9SAndroid Build Coastguard Worker scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) 770*da0073e9SAndroid Build Coastguard Worker constant_lr_target = [0.005] * 3 771*da0073e9SAndroid Build Coastguard Worker exponential_lr_target = [0.05, 0.04, 0.032] 772*da0073e9SAndroid Build Coastguard Worker step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005] 773*da0073e9SAndroid Build Coastguard Worker single_targets = constant_lr_target + exponential_lr_target + step_lr_target 774*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * 10 for x in single_targets]] 775*da0073e9SAndroid Build Coastguard Worker self._test_get_last_lr(scheduler, targets, epochs) 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker def test_chained_lr2_get_last_lr_before_step(self): 778*da0073e9SAndroid Build Coastguard Worker schedulers = [ 779*da0073e9SAndroid Build Coastguard Worker LinearLR(self.opt, start_factor=0.4, total_iters=3), 780*da0073e9SAndroid Build Coastguard Worker MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1), 781*da0073e9SAndroid Build Coastguard Worker ] 782*da0073e9SAndroid Build Coastguard Worker scheduler = 
ChainedScheduler(schedulers) 783*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker def test_chained_lr1(self): 786*da0073e9SAndroid Build Coastguard Worker epochs = 10 787*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 1 788*da0073e9SAndroid Build Coastguard Worker targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3] 789*da0073e9SAndroid Build Coastguard Worker schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) 790*da0073e9SAndroid Build Coastguard Worker scheduler = ChainedScheduler(schedulers) 791*da0073e9SAndroid Build Coastguard Worker self._test([scheduler], targets, epochs) 792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 793*da0073e9SAndroid Build Coastguard Worker 794*da0073e9SAndroid Build Coastguard Worker def test_chained_lr2(self): 795*da0073e9SAndroid Build Coastguard Worker epochs = 10 796*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 1 797*da0073e9SAndroid Build Coastguard Worker targets = [[0.02, 0.03, 0.04] + [0.05] * 9] 798*da0073e9SAndroid Build Coastguard Worker schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) 799*da0073e9SAndroid Build Coastguard Worker scheduler = ChainedScheduler(schedulers) 800*da0073e9SAndroid Build Coastguard Worker self._test([scheduler], targets, epochs) 801*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 802*da0073e9SAndroid Build Coastguard Worker 803*da0073e9SAndroid Build Coastguard Worker def test_chained_lr3(self): 804*da0073e9SAndroid Build Coastguard Worker epochs = 10 805*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 806*da0073e9SAndroid Build Coastguard Worker targets = [ 807*da0073e9SAndroid Build Coastguard Worker [0.02, 0.03, 0.04, 0.05] 
+ [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3 808*da0073e9SAndroid Build Coastguard Worker ] 809*da0073e9SAndroid Build Coastguard Worker schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) 810*da0073e9SAndroid Build Coastguard Worker schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1) 811*da0073e9SAndroid Build Coastguard Worker scheduler = ChainedScheduler(schedulers) 812*da0073e9SAndroid Build Coastguard Worker self._test([scheduler], targets, epochs) 813*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 814*da0073e9SAndroid Build Coastguard Worker 815*da0073e9SAndroid Build Coastguard Worker def test_chained_lr4(self): 816*da0073e9SAndroid Build Coastguard Worker epochs = 9 817*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 3 818*da0073e9SAndroid Build Coastguard Worker targets = [ 819*da0073e9SAndroid Build Coastguard Worker [0.05 * 0.2 * 0.9**x for x in range(3)] 820*da0073e9SAndroid Build Coastguard Worker + [0.05 * 0.2 * 0.9**3 * 0.1] 821*da0073e9SAndroid Build Coastguard Worker + [0.05 * 0.9**x * 0.1 for x in range(4, 6)] 822*da0073e9SAndroid Build Coastguard Worker + [0.05 * 0.9**x * 0.01 for x in range(6, 9)] 823*da0073e9SAndroid Build Coastguard Worker ] 824*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ExponentialLR(self.opt, gamma=0.9) 825*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4) 826*da0073e9SAndroid Build Coastguard Worker schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3) 827*da0073e9SAndroid Build Coastguard Worker scheduler = ChainedScheduler(schedulers) 828*da0073e9SAndroid Build Coastguard Worker self._test([scheduler], targets, epochs) 829*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker def 
test_chained_lr5(self): 832*da0073e9SAndroid Build Coastguard Worker def poly_lr(lr: float): 833*da0073e9SAndroid Build Coastguard Worker return [ 834*da0073e9SAndroid Build Coastguard Worker (lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters) 835*da0073e9SAndroid Build Coastguard Worker ] + [0.0] * (epochs - total_iters) 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 838*da0073e9SAndroid Build Coastguard Worker epochs = 10 839*da0073e9SAndroid Build Coastguard Worker power = 0.9 840*da0073e9SAndroid Build Coastguard Worker total_iters = 5 841*da0073e9SAndroid Build Coastguard Worker const_factor = 0.1 842*da0073e9SAndroid Build Coastguard Worker single_targets = [x * const_factor for x in poly_lr(lr=0.05)] 843*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * const_factor for x in poly_lr(0.5)]] 844*da0073e9SAndroid Build Coastguard Worker schedulers[0] = PolynomialLR(self.opt, power=power, total_iters=total_iters) 845*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ConstantLR(self.opt, factor=const_factor) 846*da0073e9SAndroid Build Coastguard Worker scheduler = ChainedScheduler(schedulers) 847*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker def test_compound_step_and_multistep_lr(self): 851*da0073e9SAndroid Build Coastguard Worker epochs = 10 852*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 853*da0073e9SAndroid Build Coastguard Worker schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) 854*da0073e9SAndroid Build Coastguard Worker schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 855*da0073e9SAndroid Build Coastguard Worker targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] 
* 2 + [5e-5] + [5e-6] * 3 + [5e-8]] 856*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Worker def test_compound_step_and_exp_lr(self): 859*da0073e9SAndroid Build Coastguard Worker epochs = 10 860*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 861*da0073e9SAndroid Build Coastguard Worker single_targets = [0.05 * (0.9**x) for x in range(3)] 862*da0073e9SAndroid Build Coastguard Worker single_targets += [0.005 * (0.9**x) for x in range(3, 6)] 863*da0073e9SAndroid Build Coastguard Worker single_targets += [0.0005 * (0.9**x) for x in range(6, 9)] 864*da0073e9SAndroid Build Coastguard Worker single_targets += [0.00005 * (0.9**x) for x in range(9, 12)] 865*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 866*da0073e9SAndroid Build Coastguard Worker schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) 867*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ExponentialLR(self.opt, gamma=0.9) 868*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 869*da0073e9SAndroid Build Coastguard Worker 870*da0073e9SAndroid Build Coastguard Worker def test_compound_exp_and_multistep_lr(self): 871*da0073e9SAndroid Build Coastguard Worker epochs = 10 872*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 873*da0073e9SAndroid Build Coastguard Worker single_targets = [0.05 * (0.9**x) for x in range(2)] 874*da0073e9SAndroid Build Coastguard Worker single_targets += [0.005 * (0.9**x) for x in range(2, 5)] 875*da0073e9SAndroid Build Coastguard Worker single_targets += [0.0005 * (0.9**x) for x in range(5, 9)] 876*da0073e9SAndroid Build Coastguard Worker single_targets += [0.00005 * (0.9**x) for x in range(9, 11)] 877*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 878*da0073e9SAndroid Build 
Coastguard Worker schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 879*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ExponentialLR(self.opt, gamma=0.9) 880*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker def test_compound_exp_and_linearlr(self): 883*da0073e9SAndroid Build Coastguard Worker epochs = 10 884*da0073e9SAndroid Build Coastguard Worker iters = 4 885*da0073e9SAndroid Build Coastguard Worker start_factor = 0.4 886*da0073e9SAndroid Build Coastguard Worker end_factor = 0.9 887*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 888*da0073e9SAndroid Build Coastguard Worker single_targets = [0.05 * (0.9**x) for x in range(11)] 889*da0073e9SAndroid Build Coastguard Worker for i in range(iters): 890*da0073e9SAndroid Build Coastguard Worker single_targets[i] *= start_factor + i / iters * (end_factor - start_factor) 891*da0073e9SAndroid Build Coastguard Worker for i in range(iters, 11): 892*da0073e9SAndroid Build Coastguard Worker single_targets[i] *= end_factor 893*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 894*da0073e9SAndroid Build Coastguard Worker schedulers[0] = LinearLR( 895*da0073e9SAndroid Build Coastguard Worker self.opt, 896*da0073e9SAndroid Build Coastguard Worker start_factor=start_factor, 897*da0073e9SAndroid Build Coastguard Worker end_factor=end_factor, 898*da0073e9SAndroid Build Coastguard Worker total_iters=iters, 899*da0073e9SAndroid Build Coastguard Worker ) 900*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ExponentialLR(self.opt, gamma=0.9) 901*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 902*da0073e9SAndroid Build Coastguard Worker 903*da0073e9SAndroid Build Coastguard Worker def test_compound_step_and_constantlr(self): 904*da0073e9SAndroid Build Coastguard Worker 
epochs = 10 905*da0073e9SAndroid Build Coastguard Worker iters = 4 906*da0073e9SAndroid Build Coastguard Worker factor = 0.4 907*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 908*da0073e9SAndroid Build Coastguard Worker single_targets = ( 909*da0073e9SAndroid Build Coastguard Worker [0.05 * 0.4] * 3 910*da0073e9SAndroid Build Coastguard Worker + [0.005 * 0.4] 911*da0073e9SAndroid Build Coastguard Worker + [0.005] * 2 912*da0073e9SAndroid Build Coastguard Worker + [0.0005] * 3 913*da0073e9SAndroid Build Coastguard Worker + [0.00005] * 3 914*da0073e9SAndroid Build Coastguard Worker ) 915*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 916*da0073e9SAndroid Build Coastguard Worker schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) 917*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4) 918*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 919*da0073e9SAndroid Build Coastguard Worker 920*da0073e9SAndroid Build Coastguard Worker def test_compound_linearlr_and_multistep_lr(self): 921*da0073e9SAndroid Build Coastguard Worker epochs = 10 922*da0073e9SAndroid Build Coastguard Worker iters = 4 923*da0073e9SAndroid Build Coastguard Worker start_factor = 0.4 924*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 925*da0073e9SAndroid Build Coastguard Worker single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2 926*da0073e9SAndroid Build Coastguard Worker for i in range(iters): 927*da0073e9SAndroid Build Coastguard Worker single_targets[i] *= start_factor + i / iters * (1 - start_factor) 928*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 929*da0073e9SAndroid Build Coastguard Worker schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 930*da0073e9SAndroid Build Coastguard Worker schedulers[1] = 
LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 931*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 932*da0073e9SAndroid Build Coastguard Worker 933*da0073e9SAndroid Build Coastguard Worker def test_compound_cosanneal_and_step_lr(self): 934*da0073e9SAndroid Build Coastguard Worker epochs = 10 935*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 936*da0073e9SAndroid Build Coastguard Worker single_targets = [ 937*da0073e9SAndroid Build Coastguard Worker eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 938*da0073e9SAndroid Build Coastguard Worker for x in range(epochs) 939*da0073e9SAndroid Build Coastguard Worker ] 940*da0073e9SAndroid Build Coastguard Worker single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)] 941*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 942*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 943*da0073e9SAndroid Build Coastguard Worker schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 944*da0073e9SAndroid Build Coastguard Worker schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) 945*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 946*da0073e9SAndroid Build Coastguard Worker 947*da0073e9SAndroid Build Coastguard Worker def test_compound_cosanneal_and_multistep_lr(self): 948*da0073e9SAndroid Build Coastguard Worker epochs = 10 949*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 950*da0073e9SAndroid Build Coastguard Worker single_targets = [ 951*da0073e9SAndroid Build Coastguard Worker eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 952*da0073e9SAndroid Build Coastguard Worker for x in range(epochs) 953*da0073e9SAndroid Build Coastguard Worker ] 954*da0073e9SAndroid Build Coastguard Worker multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001] 955*da0073e9SAndroid Build 
Coastguard Worker single_targets = [x * y for x, y in zip(single_targets, multipliers)] 956*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 957*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 958*da0073e9SAndroid Build Coastguard Worker schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 959*da0073e9SAndroid Build Coastguard Worker schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 960*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 961*da0073e9SAndroid Build Coastguard Worker 962*da0073e9SAndroid Build Coastguard Worker def test_compound_cosanneal_and_linearlr(self): 963*da0073e9SAndroid Build Coastguard Worker epochs = 10 964*da0073e9SAndroid Build Coastguard Worker iters = 4 965*da0073e9SAndroid Build Coastguard Worker start_factor = 0.4 966*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 967*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 968*da0073e9SAndroid Build Coastguard Worker single_targets = [ 969*da0073e9SAndroid Build Coastguard Worker eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 970*da0073e9SAndroid Build Coastguard Worker for x in range(epochs) 971*da0073e9SAndroid Build Coastguard Worker ] 972*da0073e9SAndroid Build Coastguard Worker for i in range(iters): 973*da0073e9SAndroid Build Coastguard Worker single_targets[i] *= start_factor + i / iters * (1 - start_factor) 974*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 975*da0073e9SAndroid Build Coastguard Worker schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 976*da0073e9SAndroid Build Coastguard Worker schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 977*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 978*da0073e9SAndroid Build Coastguard Worker 
979*da0073e9SAndroid Build Coastguard Worker def test_compound_cosanneal_and_exp_lr(self): 980*da0073e9SAndroid Build Coastguard Worker epochs = 10 981*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 982*da0073e9SAndroid Build Coastguard Worker single_targets = [ 983*da0073e9SAndroid Build Coastguard Worker eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 984*da0073e9SAndroid Build Coastguard Worker for x in range(epochs) 985*da0073e9SAndroid Build Coastguard Worker ] 986*da0073e9SAndroid Build Coastguard Worker multipliers = [0.1**i for i in range(epochs)] 987*da0073e9SAndroid Build Coastguard Worker single_targets = [x * y for x, y in zip(single_targets, multipliers)] 988*da0073e9SAndroid Build Coastguard Worker targets = [single_targets, [x * epochs for x in single_targets]] 989*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 990*da0073e9SAndroid Build Coastguard Worker schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 991*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ExponentialLR(self.opt, gamma=0.1) 992*da0073e9SAndroid Build Coastguard Worker self._test(schedulers, targets, epochs) 993*da0073e9SAndroid Build Coastguard Worker 994*da0073e9SAndroid Build Coastguard Worker def test_compound_reduce_lr_on_plateau1(self): 995*da0073e9SAndroid Build Coastguard Worker epochs = 10 996*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 997*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 998*da0073e9SAndroid Build Coastguard Worker single_targets = [0.5] * 20 999*da0073e9SAndroid Build Coastguard Worker multipliers = [0.1 ** (i // 3) for i in range(20)] 1000*da0073e9SAndroid Build Coastguard Worker single_targets = [x * y for x, y in zip(multipliers, single_targets)] 1001*da0073e9SAndroid Build Coastguard Worker targets = [single_targets] 1002*da0073e9SAndroid Build Coastguard Worker targets = targets[1:] # test runs step before checking 
lr 1003*da0073e9SAndroid Build Coastguard Worker metrics = [10 - i * 0.0167 for i in range(20)] 1004*da0073e9SAndroid Build Coastguard Worker schedulers = [None, None] 1005*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ReduceLROnPlateau( 1006*da0073e9SAndroid Build Coastguard Worker self.opt, 1007*da0073e9SAndroid Build Coastguard Worker threshold_mode="abs", 1008*da0073e9SAndroid Build Coastguard Worker mode="min", 1009*da0073e9SAndroid Build Coastguard Worker threshold=0.01, 1010*da0073e9SAndroid Build Coastguard Worker patience=5, 1011*da0073e9SAndroid Build Coastguard Worker cooldown=5, 1012*da0073e9SAndroid Build Coastguard Worker ) 1013*da0073e9SAndroid Build Coastguard Worker schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) 1014*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker def test_compound_reduce_lr_on_plateau2(self): 1017*da0073e9SAndroid Build Coastguard Worker epochs = 22 1018*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 1019*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 1020*da0073e9SAndroid Build Coastguard Worker single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 1021*da0073e9SAndroid Build Coastguard Worker multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10 1022*da0073e9SAndroid Build Coastguard Worker single_targets = [x * y for x, y in zip(single_targets, multipliers)] 1023*da0073e9SAndroid Build Coastguard Worker targets = [single_targets] 1024*da0073e9SAndroid Build Coastguard Worker targets = targets[1:] # test runs step before checking lr 1025*da0073e9SAndroid Build Coastguard Worker metrics = [10 - i * 0.0165 for i in range(22)] 1026*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 1027*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ReduceLROnPlateau( 
1028*da0073e9SAndroid Build Coastguard Worker self.opt, 1029*da0073e9SAndroid Build Coastguard Worker patience=5, 1030*da0073e9SAndroid Build Coastguard Worker cooldown=0, 1031*da0073e9SAndroid Build Coastguard Worker threshold_mode="abs", 1032*da0073e9SAndroid Build Coastguard Worker mode="min", 1033*da0073e9SAndroid Build Coastguard Worker threshold=0.1, 1034*da0073e9SAndroid Build Coastguard Worker ) 1035*da0073e9SAndroid Build Coastguard Worker schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12]) 1036*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1037*da0073e9SAndroid Build Coastguard Worker 1038*da0073e9SAndroid Build Coastguard Worker def test_compound_reduce_lr_on_plateau3(self): 1039*da0073e9SAndroid Build Coastguard Worker epochs = 22 1040*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 1041*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 1042*da0073e9SAndroid Build Coastguard Worker single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4 1043*da0073e9SAndroid Build Coastguard Worker multipliers = [0.1**i for i in range(epochs)] 1044*da0073e9SAndroid Build Coastguard Worker single_targets = [x * y for x, y in zip(multipliers, single_targets)] 1045*da0073e9SAndroid Build Coastguard Worker targets = [single_targets] 1046*da0073e9SAndroid Build Coastguard Worker targets = targets[1:] # test runs step before checking lr 1047*da0073e9SAndroid Build Coastguard Worker metrics = [-0.8] * 2 + [-0.234] * 20 1048*da0073e9SAndroid Build Coastguard Worker schedulers = [None, None] 1049*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ReduceLROnPlateau( 1050*da0073e9SAndroid Build Coastguard Worker self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs" 1051*da0073e9SAndroid Build Coastguard Worker ) 1052*da0073e9SAndroid Build Coastguard Worker schedulers[1] = ExponentialLR(self.opt, gamma=0.1) 
1053*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1054*da0073e9SAndroid Build Coastguard Worker 1055*da0073e9SAndroid Build Coastguard Worker def test_compound_reduce_lr_on_plateau4(self): 1056*da0073e9SAndroid Build Coastguard Worker epochs = 20 1057*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 1058*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.05 1059*da0073e9SAndroid Build Coastguard Worker epochs = 10 1060*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 1061*da0073e9SAndroid Build Coastguard Worker single_targets = [ 1062*da0073e9SAndroid Build Coastguard Worker eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 1063*da0073e9SAndroid Build Coastguard Worker for x in range(epochs) 1064*da0073e9SAndroid Build Coastguard Worker ] 1065*da0073e9SAndroid Build Coastguard Worker targets = [single_targets] 1066*da0073e9SAndroid Build Coastguard Worker targets = targets[1:] # test runs step before checking lr 1067*da0073e9SAndroid Build Coastguard Worker metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25 1068*da0073e9SAndroid Build Coastguard Worker schedulers = [None, None] 1069*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ReduceLROnPlateau( 1070*da0073e9SAndroid Build Coastguard Worker self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1 1071*da0073e9SAndroid Build Coastguard Worker ) 1072*da0073e9SAndroid Build Coastguard Worker schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min) 1073*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Worker def test_compound_reduce_lr_on_plateau5(self): 1076*da0073e9SAndroid Build Coastguard Worker iters = 4 1077*da0073e9SAndroid Build Coastguard Worker start_factor = 0.4 
1078*da0073e9SAndroid Build Coastguard Worker epochs = 22 1079*da0073e9SAndroid Build Coastguard Worker for param_group in self.opt.param_groups: 1080*da0073e9SAndroid Build Coastguard Worker param_group["lr"] = 0.5 1081*da0073e9SAndroid Build Coastguard Worker single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 1082*da0073e9SAndroid Build Coastguard Worker multipliers = [1] * 22 1083*da0073e9SAndroid Build Coastguard Worker for i in range(iters): 1084*da0073e9SAndroid Build Coastguard Worker multipliers[i] *= start_factor + i / iters * (1 - start_factor) 1085*da0073e9SAndroid Build Coastguard Worker single_targets = [x * y for x, y in zip(single_targets, multipliers)] 1086*da0073e9SAndroid Build Coastguard Worker targets = [single_targets] 1087*da0073e9SAndroid Build Coastguard Worker targets = targets[1:] # test runs step before checking lr 1088*da0073e9SAndroid Build Coastguard Worker metrics = [10 - i * 0.0165 for i in range(22)] 1089*da0073e9SAndroid Build Coastguard Worker schedulers = [None] * 2 1090*da0073e9SAndroid Build Coastguard Worker schedulers[0] = ReduceLROnPlateau( 1091*da0073e9SAndroid Build Coastguard Worker self.opt, 1092*da0073e9SAndroid Build Coastguard Worker patience=5, 1093*da0073e9SAndroid Build Coastguard Worker cooldown=0, 1094*da0073e9SAndroid Build Coastguard Worker threshold_mode="abs", 1095*da0073e9SAndroid Build Coastguard Worker mode="min", 1096*da0073e9SAndroid Build Coastguard Worker threshold=0.1, 1097*da0073e9SAndroid Build Coastguard Worker ) 1098*da0073e9SAndroid Build Coastguard Worker schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 1099*da0073e9SAndroid Build Coastguard Worker self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1100*da0073e9SAndroid Build Coastguard Worker 1101*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_invalid_mode(self): 1102*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1103*da0073e9SAndroid 
Build Coastguard Worker scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS") 1104*da0073e9SAndroid Build Coastguard Worker 1105*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_triangular_mode_one_lr(self): 1106*da0073e9SAndroid Build Coastguard Worker lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] 1107*da0073e9SAndroid Build Coastguard Worker momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] 1108*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1109*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1110*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1111*da0073e9SAndroid Build Coastguard Worker self.opt, 1112*da0073e9SAndroid Build Coastguard Worker base_lr=1, 1113*da0073e9SAndroid Build Coastguard Worker max_lr=5, 1114*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1115*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1116*da0073e9SAndroid Build Coastguard Worker base_momentum=1, 1117*da0073e9SAndroid Build Coastguard Worker max_momentum=5, 1118*da0073e9SAndroid Build Coastguard Worker mode="triangular", 1119*da0073e9SAndroid Build Coastguard Worker ) 1120*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1121*da0073e9SAndroid Build Coastguard Worker 1122*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_triangular_mode_one_lr_no_momentum(self): 1123*da0073e9SAndroid Build Coastguard Worker lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] 1124*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1125*da0073e9SAndroid Build Coastguard Worker momentum_target = [self.opt.defaults["momentum"]] * len(lr_target) 1126*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1127*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1128*da0073e9SAndroid Build Coastguard Worker self.opt, 
1129*da0073e9SAndroid Build Coastguard Worker base_lr=1, 1130*da0073e9SAndroid Build Coastguard Worker max_lr=5, 1131*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1132*da0073e9SAndroid Build Coastguard Worker cycle_momentum=False, 1133*da0073e9SAndroid Build Coastguard Worker mode="triangular", 1134*da0073e9SAndroid Build Coastguard Worker ) 1135*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1136*da0073e9SAndroid Build Coastguard Worker 1137*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_triangular2_mode_one_lr(self): 1138*da0073e9SAndroid Build Coastguard Worker lr_target = [ 1139*da0073e9SAndroid Build Coastguard Worker 1, 1140*da0073e9SAndroid Build Coastguard Worker 2, 1141*da0073e9SAndroid Build Coastguard Worker 3, 1142*da0073e9SAndroid Build Coastguard Worker 4, 1143*da0073e9SAndroid Build Coastguard Worker 5, 1144*da0073e9SAndroid Build Coastguard Worker 4, 1145*da0073e9SAndroid Build Coastguard Worker 3, 1146*da0073e9SAndroid Build Coastguard Worker 2, 1147*da0073e9SAndroid Build Coastguard Worker 1, 1148*da0073e9SAndroid Build Coastguard Worker 1.5, 1149*da0073e9SAndroid Build Coastguard Worker 2.0, 1150*da0073e9SAndroid Build Coastguard Worker 2.5, 1151*da0073e9SAndroid Build Coastguard Worker 3.0, 1152*da0073e9SAndroid Build Coastguard Worker 2.5, 1153*da0073e9SAndroid Build Coastguard Worker 2.0, 1154*da0073e9SAndroid Build Coastguard Worker 1.5, 1155*da0073e9SAndroid Build Coastguard Worker 1, 1156*da0073e9SAndroid Build Coastguard Worker 1.25, 1157*da0073e9SAndroid Build Coastguard Worker 1.50, 1158*da0073e9SAndroid Build Coastguard Worker 1.75, 1159*da0073e9SAndroid Build Coastguard Worker 2.00, 1160*da0073e9SAndroid Build Coastguard Worker 1.75, 1161*da0073e9SAndroid Build Coastguard Worker ] 1162*da0073e9SAndroid Build Coastguard Worker momentum_target = [ 1163*da0073e9SAndroid Build Coastguard Worker 5.0, 1164*da0073e9SAndroid Build Coastguard Worker 
4.0, 1165*da0073e9SAndroid Build Coastguard Worker 3.0, 1166*da0073e9SAndroid Build Coastguard Worker 2.0, 1167*da0073e9SAndroid Build Coastguard Worker 1.0, 1168*da0073e9SAndroid Build Coastguard Worker 2.0, 1169*da0073e9SAndroid Build Coastguard Worker 3.0, 1170*da0073e9SAndroid Build Coastguard Worker 4.0, 1171*da0073e9SAndroid Build Coastguard Worker 5.0, 1172*da0073e9SAndroid Build Coastguard Worker 4.5, 1173*da0073e9SAndroid Build Coastguard Worker 4.0, 1174*da0073e9SAndroid Build Coastguard Worker 3.5, 1175*da0073e9SAndroid Build Coastguard Worker 3.0, 1176*da0073e9SAndroid Build Coastguard Worker 3.5, 1177*da0073e9SAndroid Build Coastguard Worker 4.0, 1178*da0073e9SAndroid Build Coastguard Worker 4.5, 1179*da0073e9SAndroid Build Coastguard Worker 5.0, 1180*da0073e9SAndroid Build Coastguard Worker 4.75, 1181*da0073e9SAndroid Build Coastguard Worker 4.5, 1182*da0073e9SAndroid Build Coastguard Worker 4.25, 1183*da0073e9SAndroid Build Coastguard Worker 4.0, 1184*da0073e9SAndroid Build Coastguard Worker 4.25, 1185*da0073e9SAndroid Build Coastguard Worker ] 1186*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1187*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1188*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1189*da0073e9SAndroid Build Coastguard Worker self.opt, 1190*da0073e9SAndroid Build Coastguard Worker base_lr=1, 1191*da0073e9SAndroid Build Coastguard Worker max_lr=5, 1192*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1193*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1194*da0073e9SAndroid Build Coastguard Worker base_momentum=1, 1195*da0073e9SAndroid Build Coastguard Worker max_momentum=5, 1196*da0073e9SAndroid Build Coastguard Worker mode="triangular2", 1197*da0073e9SAndroid Build Coastguard Worker ) 1198*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 
1199*da0073e9SAndroid Build Coastguard Worker 1200*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_exp_range_mode_one_lr(self): 1201*da0073e9SAndroid Build Coastguard Worker base_lr, max_lr = 1, 5 1202*da0073e9SAndroid Build Coastguard Worker diff_lr = max_lr - base_lr 1203*da0073e9SAndroid Build Coastguard Worker gamma = 0.9 1204*da0073e9SAndroid Build Coastguard Worker xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] 1205*da0073e9SAndroid Build Coastguard Worker lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] 1206*da0073e9SAndroid Build Coastguard Worker momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] 1207*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1208*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1209*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1210*da0073e9SAndroid Build Coastguard Worker self.opt, 1211*da0073e9SAndroid Build Coastguard Worker base_lr=base_lr, 1212*da0073e9SAndroid Build Coastguard Worker max_lr=max_lr, 1213*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1214*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1215*da0073e9SAndroid Build Coastguard Worker base_momentum=base_lr, 1216*da0073e9SAndroid Build Coastguard Worker max_momentum=max_lr, 1217*da0073e9SAndroid Build Coastguard Worker mode="exp_range", 1218*da0073e9SAndroid Build Coastguard Worker gamma=gamma, 1219*da0073e9SAndroid Build Coastguard Worker ) 1220*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1221*da0073e9SAndroid Build Coastguard Worker 1222*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_triangular_mode(self): 1223*da0073e9SAndroid Build Coastguard Worker lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] 1224*da0073e9SAndroid Build Coastguard Worker lr_target_2 = [x + 1 for x in 
lr_target_1] 1225*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target_1, lr_target_2] 1226*da0073e9SAndroid Build Coastguard Worker momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] 1227*da0073e9SAndroid Build Coastguard Worker momentum_target_2 = [x + 1 for x in momentum_target_1] 1228*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target_1, momentum_target_2] 1229*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1230*da0073e9SAndroid Build Coastguard Worker self.opt, 1231*da0073e9SAndroid Build Coastguard Worker base_lr=[1, 2], 1232*da0073e9SAndroid Build Coastguard Worker max_lr=[5, 6], 1233*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1234*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1235*da0073e9SAndroid Build Coastguard Worker base_momentum=[1, 2], 1236*da0073e9SAndroid Build Coastguard Worker max_momentum=[5, 6], 1237*da0073e9SAndroid Build Coastguard Worker mode="triangular", 1238*da0073e9SAndroid Build Coastguard Worker ) 1239*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) 1240*da0073e9SAndroid Build Coastguard Worker 1241*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_triangular2_mode(self): 1242*da0073e9SAndroid Build Coastguard Worker lr_target_1 = [ 1243*da0073e9SAndroid Build Coastguard Worker 1, 1244*da0073e9SAndroid Build Coastguard Worker 2, 1245*da0073e9SAndroid Build Coastguard Worker 3, 1246*da0073e9SAndroid Build Coastguard Worker 4, 1247*da0073e9SAndroid Build Coastguard Worker 5, 1248*da0073e9SAndroid Build Coastguard Worker 4, 1249*da0073e9SAndroid Build Coastguard Worker 3, 1250*da0073e9SAndroid Build Coastguard Worker 2, 1251*da0073e9SAndroid Build Coastguard Worker 1, 1252*da0073e9SAndroid Build Coastguard Worker 1.5, 1253*da0073e9SAndroid Build Coastguard Worker 2.0, 1254*da0073e9SAndroid Build Coastguard Worker 2.5, 1255*da0073e9SAndroid Build Coastguard Worker 3.0, 
1256*da0073e9SAndroid Build Coastguard Worker 2.5, 1257*da0073e9SAndroid Build Coastguard Worker 2.0, 1258*da0073e9SAndroid Build Coastguard Worker 1.5, 1259*da0073e9SAndroid Build Coastguard Worker 1, 1260*da0073e9SAndroid Build Coastguard Worker 1.25, 1261*da0073e9SAndroid Build Coastguard Worker 1.50, 1262*da0073e9SAndroid Build Coastguard Worker 1.75, 1263*da0073e9SAndroid Build Coastguard Worker 2.00, 1264*da0073e9SAndroid Build Coastguard Worker 1.75, 1265*da0073e9SAndroid Build Coastguard Worker ] 1266*da0073e9SAndroid Build Coastguard Worker lr_target_2 = [x + 2 for x in lr_target_1] 1267*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target_1, lr_target_2] 1268*da0073e9SAndroid Build Coastguard Worker momentum_target_1 = [ 1269*da0073e9SAndroid Build Coastguard Worker 5.0, 1270*da0073e9SAndroid Build Coastguard Worker 4.0, 1271*da0073e9SAndroid Build Coastguard Worker 3.0, 1272*da0073e9SAndroid Build Coastguard Worker 2.0, 1273*da0073e9SAndroid Build Coastguard Worker 1.0, 1274*da0073e9SAndroid Build Coastguard Worker 2.0, 1275*da0073e9SAndroid Build Coastguard Worker 3.0, 1276*da0073e9SAndroid Build Coastguard Worker 4.0, 1277*da0073e9SAndroid Build Coastguard Worker 5.0, 1278*da0073e9SAndroid Build Coastguard Worker 4.5, 1279*da0073e9SAndroid Build Coastguard Worker 4.0, 1280*da0073e9SAndroid Build Coastguard Worker 3.5, 1281*da0073e9SAndroid Build Coastguard Worker 3.0, 1282*da0073e9SAndroid Build Coastguard Worker 3.5, 1283*da0073e9SAndroid Build Coastguard Worker 4.0, 1284*da0073e9SAndroid Build Coastguard Worker 4.5, 1285*da0073e9SAndroid Build Coastguard Worker 5.0, 1286*da0073e9SAndroid Build Coastguard Worker 4.75, 1287*da0073e9SAndroid Build Coastguard Worker 4.5, 1288*da0073e9SAndroid Build Coastguard Worker 4.25, 1289*da0073e9SAndroid Build Coastguard Worker 4.0, 1290*da0073e9SAndroid Build Coastguard Worker 4.25, 1291*da0073e9SAndroid Build Coastguard Worker ] 1292*da0073e9SAndroid Build Coastguard Worker momentum_target_2 = [x + 2 
for x in momentum_target_1] 1293*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target_1, momentum_target_2] 1294*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1295*da0073e9SAndroid Build Coastguard Worker self.opt, 1296*da0073e9SAndroid Build Coastguard Worker base_lr=[1, 3], 1297*da0073e9SAndroid Build Coastguard Worker max_lr=[5, 7], 1298*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1299*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1300*da0073e9SAndroid Build Coastguard Worker base_momentum=[1, 3], 1301*da0073e9SAndroid Build Coastguard Worker max_momentum=[5, 7], 1302*da0073e9SAndroid Build Coastguard Worker mode="triangular2", 1303*da0073e9SAndroid Build Coastguard Worker ) 1304*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) 1305*da0073e9SAndroid Build Coastguard Worker 1306*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_exp_range_mode(self): 1307*da0073e9SAndroid Build Coastguard Worker base_lr_1, max_lr_1 = 1, 5 1308*da0073e9SAndroid Build Coastguard Worker base_lr_2, max_lr_2 = 5, 12 1309*da0073e9SAndroid Build Coastguard Worker 1310*da0073e9SAndroid Build Coastguard Worker diff_lr_1 = max_lr_1 - base_lr_1 1311*da0073e9SAndroid Build Coastguard Worker diff_lr_2 = max_lr_2 - base_lr_2 1312*da0073e9SAndroid Build Coastguard Worker 1313*da0073e9SAndroid Build Coastguard Worker gamma = 0.9 1314*da0073e9SAndroid Build Coastguard Worker xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] 1315*da0073e9SAndroid Build Coastguard Worker lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] 1316*da0073e9SAndroid Build Coastguard Worker lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] 1317*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target_1, lr_target_2] 1318*da0073e9SAndroid Build Coastguard Worker momentum_target_1 = [ 
1319*da0073e9SAndroid Build Coastguard Worker max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs) 1320*da0073e9SAndroid Build Coastguard Worker ] 1321*da0073e9SAndroid Build Coastguard Worker momentum_target_2 = [ 1322*da0073e9SAndroid Build Coastguard Worker max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs) 1323*da0073e9SAndroid Build Coastguard Worker ] 1324*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target_1, momentum_target_2] 1325*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1326*da0073e9SAndroid Build Coastguard Worker self.opt, 1327*da0073e9SAndroid Build Coastguard Worker base_lr=[base_lr_1, base_lr_2], 1328*da0073e9SAndroid Build Coastguard Worker max_lr=[max_lr_1, max_lr_2], 1329*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1330*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1331*da0073e9SAndroid Build Coastguard Worker base_momentum=[base_lr_1, base_lr_2], 1332*da0073e9SAndroid Build Coastguard Worker max_momentum=[max_lr_1, max_lr_2], 1333*da0073e9SAndroid Build Coastguard Worker mode="exp_range", 1334*da0073e9SAndroid Build Coastguard Worker gamma=gamma, 1335*da0073e9SAndroid Build Coastguard Worker ) 1336*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) 1337*da0073e9SAndroid Build Coastguard Worker 1338*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_triangular_mode_step_size_up_down(self): 1339*da0073e9SAndroid Build Coastguard Worker lr_target = [ 1340*da0073e9SAndroid Build Coastguard Worker 1.0, 1341*da0073e9SAndroid Build Coastguard Worker 2.0, 1342*da0073e9SAndroid Build Coastguard Worker 3.0, 1343*da0073e9SAndroid Build Coastguard Worker 4.0, 1344*da0073e9SAndroid Build Coastguard Worker 5.0, 1345*da0073e9SAndroid Build Coastguard Worker 13.0 / 3, 1346*da0073e9SAndroid Build Coastguard Worker 11.0 / 3, 1347*da0073e9SAndroid Build Coastguard Worker 9.0 / 3, 
1348*da0073e9SAndroid Build Coastguard Worker 7.0 / 3, 1349*da0073e9SAndroid Build Coastguard Worker 5.0 / 3, 1350*da0073e9SAndroid Build Coastguard Worker 1.0, 1351*da0073e9SAndroid Build Coastguard Worker ] 1352*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1353*da0073e9SAndroid Build Coastguard Worker momentum_target = [ 1354*da0073e9SAndroid Build Coastguard Worker 5.0, 1355*da0073e9SAndroid Build Coastguard Worker 4.0, 1356*da0073e9SAndroid Build Coastguard Worker 3.0, 1357*da0073e9SAndroid Build Coastguard Worker 2.0, 1358*da0073e9SAndroid Build Coastguard Worker 1.0, 1359*da0073e9SAndroid Build Coastguard Worker 5.0 / 3, 1360*da0073e9SAndroid Build Coastguard Worker 7.0 / 3, 1361*da0073e9SAndroid Build Coastguard Worker 3.0, 1362*da0073e9SAndroid Build Coastguard Worker 11.0 / 3, 1363*da0073e9SAndroid Build Coastguard Worker 13.0 / 3, 1364*da0073e9SAndroid Build Coastguard Worker 5.0, 1365*da0073e9SAndroid Build Coastguard Worker ] 1366*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1367*da0073e9SAndroid Build Coastguard Worker 1368*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1369*da0073e9SAndroid Build Coastguard Worker self.opt, 1370*da0073e9SAndroid Build Coastguard Worker base_lr=1, 1371*da0073e9SAndroid Build Coastguard Worker max_lr=5, 1372*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1373*da0073e9SAndroid Build Coastguard Worker step_size_down=6, 1374*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1375*da0073e9SAndroid Build Coastguard Worker base_momentum=1, 1376*da0073e9SAndroid Build Coastguard Worker max_momentum=5, 1377*da0073e9SAndroid Build Coastguard Worker mode="triangular", 1378*da0073e9SAndroid Build Coastguard Worker ) 1379*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1380*da0073e9SAndroid Build Coastguard Worker 1381*da0073e9SAndroid Build 
Coastguard Worker def test_cycle_lr_triangular2_mode_step_size_up_down(self): 1382*da0073e9SAndroid Build Coastguard Worker lr_base_target = [ 1383*da0073e9SAndroid Build Coastguard Worker 1.0, 1384*da0073e9SAndroid Build Coastguard Worker 3.0, 1385*da0073e9SAndroid Build Coastguard Worker 5.0, 1386*da0073e9SAndroid Build Coastguard Worker 13.0 / 3, 1387*da0073e9SAndroid Build Coastguard Worker 11.0 / 3, 1388*da0073e9SAndroid Build Coastguard Worker 9.0 / 3, 1389*da0073e9SAndroid Build Coastguard Worker 7.0 / 3, 1390*da0073e9SAndroid Build Coastguard Worker 5.0 / 3, 1391*da0073e9SAndroid Build Coastguard Worker 1.0, 1392*da0073e9SAndroid Build Coastguard Worker 2.0, 1393*da0073e9SAndroid Build Coastguard Worker 3.0, 1394*da0073e9SAndroid Build Coastguard Worker 8.0 / 3, 1395*da0073e9SAndroid Build Coastguard Worker 7.0 / 3, 1396*da0073e9SAndroid Build Coastguard Worker 6.0 / 3, 1397*da0073e9SAndroid Build Coastguard Worker 5.0 / 3, 1398*da0073e9SAndroid Build Coastguard Worker 4.0 / 3, 1399*da0073e9SAndroid Build Coastguard Worker 1.0, 1400*da0073e9SAndroid Build Coastguard Worker 3.0 / 2, 1401*da0073e9SAndroid Build Coastguard Worker 2.0, 1402*da0073e9SAndroid Build Coastguard Worker 11.0 / 6, 1403*da0073e9SAndroid Build Coastguard Worker 10.0 / 6, 1404*da0073e9SAndroid Build Coastguard Worker 9.0 / 6, 1405*da0073e9SAndroid Build Coastguard Worker 8.0 / 6, 1406*da0073e9SAndroid Build Coastguard Worker 7.0 / 6, 1407*da0073e9SAndroid Build Coastguard Worker ] 1408*da0073e9SAndroid Build Coastguard Worker momentum_base_target = [ 1409*da0073e9SAndroid Build Coastguard Worker 5.0, 1410*da0073e9SAndroid Build Coastguard Worker 3.0, 1411*da0073e9SAndroid Build Coastguard Worker 1.0, 1412*da0073e9SAndroid Build Coastguard Worker 5.0 / 3, 1413*da0073e9SAndroid Build Coastguard Worker 7.0 / 3, 1414*da0073e9SAndroid Build Coastguard Worker 3.0, 1415*da0073e9SAndroid Build Coastguard Worker 11.0 / 3, 1416*da0073e9SAndroid Build Coastguard Worker 13.0 / 3, 
1417*da0073e9SAndroid Build Coastguard Worker 5.0, 1418*da0073e9SAndroid Build Coastguard Worker 4.0, 1419*da0073e9SAndroid Build Coastguard Worker 3.0, 1420*da0073e9SAndroid Build Coastguard Worker 10.0 / 3, 1421*da0073e9SAndroid Build Coastguard Worker 11.0 / 3, 1422*da0073e9SAndroid Build Coastguard Worker 4.0, 1423*da0073e9SAndroid Build Coastguard Worker 13.0 / 3, 1424*da0073e9SAndroid Build Coastguard Worker 14.0 / 3, 1425*da0073e9SAndroid Build Coastguard Worker 5.0, 1426*da0073e9SAndroid Build Coastguard Worker 4.5, 1427*da0073e9SAndroid Build Coastguard Worker 4.0, 1428*da0073e9SAndroid Build Coastguard Worker 25.0 / 6, 1429*da0073e9SAndroid Build Coastguard Worker 13.0 / 3, 1430*da0073e9SAndroid Build Coastguard Worker 4.5, 1431*da0073e9SAndroid Build Coastguard Worker 14.0 / 3, 1432*da0073e9SAndroid Build Coastguard Worker 29.0 / 6, 1433*da0073e9SAndroid Build Coastguard Worker ] 1434*da0073e9SAndroid Build Coastguard Worker deltas = [2 * i for i in range(0, 2)] 1435*da0073e9SAndroid Build Coastguard Worker base_lrs = [1 + delta for delta in deltas] 1436*da0073e9SAndroid Build Coastguard Worker max_lrs = [5 + delta for delta in deltas] 1437*da0073e9SAndroid Build Coastguard Worker lr_targets = [[x + delta for x in lr_base_target] for delta in deltas] 1438*da0073e9SAndroid Build Coastguard Worker momentum_targets = [ 1439*da0073e9SAndroid Build Coastguard Worker [x + delta for x in momentum_base_target] for delta in deltas 1440*da0073e9SAndroid Build Coastguard Worker ] 1441*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1442*da0073e9SAndroid Build Coastguard Worker self.opt, 1443*da0073e9SAndroid Build Coastguard Worker base_lr=base_lrs, 1444*da0073e9SAndroid Build Coastguard Worker max_lr=max_lrs, 1445*da0073e9SAndroid Build Coastguard Worker step_size_up=2, 1446*da0073e9SAndroid Build Coastguard Worker step_size_down=6, 1447*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1448*da0073e9SAndroid Build Coastguard Worker 
base_momentum=base_lrs, 1449*da0073e9SAndroid Build Coastguard Worker max_momentum=max_lrs, 1450*da0073e9SAndroid Build Coastguard Worker mode="triangular2", 1451*da0073e9SAndroid Build Coastguard Worker ) 1452*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr( 1453*da0073e9SAndroid Build Coastguard Worker scheduler, lr_targets, momentum_targets, len(lr_base_target) 1454*da0073e9SAndroid Build Coastguard Worker ) 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_exp_range_mode_step_size_up_down(self): 1457*da0073e9SAndroid Build Coastguard Worker base_lr, max_lr = 1, 5 1458*da0073e9SAndroid Build Coastguard Worker diff_lr = max_lr - base_lr 1459*da0073e9SAndroid Build Coastguard Worker gamma = 0.9 1460*da0073e9SAndroid Build Coastguard Worker xs = [ 1461*da0073e9SAndroid Build Coastguard Worker 0.0, 1462*da0073e9SAndroid Build Coastguard Worker 0.5, 1463*da0073e9SAndroid Build Coastguard Worker 1.0, 1464*da0073e9SAndroid Build Coastguard Worker 5.0 / 6, 1465*da0073e9SAndroid Build Coastguard Worker 4.0 / 6, 1466*da0073e9SAndroid Build Coastguard Worker 3.0 / 6, 1467*da0073e9SAndroid Build Coastguard Worker 2.0 / 6, 1468*da0073e9SAndroid Build Coastguard Worker 1.0 / 6, 1469*da0073e9SAndroid Build Coastguard Worker 0.0, 1470*da0073e9SAndroid Build Coastguard Worker 0.5, 1471*da0073e9SAndroid Build Coastguard Worker 1.0, 1472*da0073e9SAndroid Build Coastguard Worker 5.0 / 6, 1473*da0073e9SAndroid Build Coastguard Worker 4.0 / 6, 1474*da0073e9SAndroid Build Coastguard Worker ] 1475*da0073e9SAndroid Build Coastguard Worker lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] 1476*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1477*da0073e9SAndroid Build Coastguard Worker momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] 1478*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 
1479*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1480*da0073e9SAndroid Build Coastguard Worker self.opt, 1481*da0073e9SAndroid Build Coastguard Worker base_lr=base_lr, 1482*da0073e9SAndroid Build Coastguard Worker max_lr=max_lr, 1483*da0073e9SAndroid Build Coastguard Worker step_size_up=2, 1484*da0073e9SAndroid Build Coastguard Worker step_size_down=6, 1485*da0073e9SAndroid Build Coastguard Worker cycle_momentum=True, 1486*da0073e9SAndroid Build Coastguard Worker base_momentum=base_lr, 1487*da0073e9SAndroid Build Coastguard Worker max_momentum=max_lr, 1488*da0073e9SAndroid Build Coastguard Worker mode="exp_range", 1489*da0073e9SAndroid Build Coastguard Worker gamma=gamma, 1490*da0073e9SAndroid Build Coastguard Worker ) 1491*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1492*da0073e9SAndroid Build Coastguard Worker 1493*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_with_momentumless_optimizer(self): 1494*da0073e9SAndroid Build Coastguard Worker # Note [Temporarily set optimizer to Adam] 1495*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1496*da0073e9SAndroid Build Coastguard Worker # The TestLRScheduler object carries around an SGD optimizer to avoid having to 1497*da0073e9SAndroid Build Coastguard Worker # instantiate one for every test. This gets in the way for our very specific case 1498*da0073e9SAndroid Build Coastguard Worker # in which we need to use Adam (or really any optimizer that doesn't use momentum) 1499*da0073e9SAndroid Build Coastguard Worker # in order to test that the momentum bug in CyclicLR is fixed (the bug is described 1500*da0073e9SAndroid Build Coastguard Worker # in more detail in https://github.com/pytorch/pytorch/issues/19003 ). 
1501*da0073e9SAndroid Build Coastguard Worker old_opt = self.opt 1502*da0073e9SAndroid Build Coastguard Worker self.opt = Adam( 1503*da0073e9SAndroid Build Coastguard Worker [ 1504*da0073e9SAndroid Build Coastguard Worker {"params": self.net.conv1.parameters()}, 1505*da0073e9SAndroid Build Coastguard Worker {"params": self.net.conv2.parameters(), "lr": 0.5}, 1506*da0073e9SAndroid Build Coastguard Worker ], 1507*da0073e9SAndroid Build Coastguard Worker lr=0.05, 1508*da0073e9SAndroid Build Coastguard Worker ) 1509*da0073e9SAndroid Build Coastguard Worker 1510*da0073e9SAndroid Build Coastguard Worker lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] 1511*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1512*da0073e9SAndroid Build Coastguard Worker momentum_target = [None] * len(lr_target) 1513*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1514*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1515*da0073e9SAndroid Build Coastguard Worker self.opt, 1516*da0073e9SAndroid Build Coastguard Worker base_lr=1, 1517*da0073e9SAndroid Build Coastguard Worker max_lr=5, 1518*da0073e9SAndroid Build Coastguard Worker step_size_up=4, 1519*da0073e9SAndroid Build Coastguard Worker cycle_momentum=False, 1520*da0073e9SAndroid Build Coastguard Worker mode="triangular", 1521*da0073e9SAndroid Build Coastguard Worker ) 1522*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker self.opt = old_opt # set optimizer back to SGD 1525*da0073e9SAndroid Build Coastguard Worker 1526*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self): 1527*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1528*da0073e9SAndroid Build Coastguard Worker rprop_opt = Rprop(self.net.parameters()) 
1529*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True) 1530*da0073e9SAndroid Build Coastguard Worker 1531*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self): 1532*da0073e9SAndroid Build Coastguard Worker adam_opt = Adam(self.net.parameters()) 1533*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True) 1534*da0073e9SAndroid Build Coastguard Worker 1535*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_removed_after_out_of_scope(self): 1536*da0073e9SAndroid Build Coastguard Worker import gc 1537*da0073e9SAndroid Build Coastguard Worker import weakref 1538*da0073e9SAndroid Build Coastguard Worker 1539*da0073e9SAndroid Build Coastguard Worker gc.disable() 1540*da0073e9SAndroid Build Coastguard Worker 1541*da0073e9SAndroid Build Coastguard Worker def test(): 1542*da0073e9SAndroid Build Coastguard Worker adam_opt = Adam(self.net.parameters()) 1543*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False) 1544*da0073e9SAndroid Build Coastguard Worker return weakref.ref(scheduler) 1545*da0073e9SAndroid Build Coastguard Worker 1546*da0073e9SAndroid Build Coastguard Worker ref = test() 1547*da0073e9SAndroid Build Coastguard Worker assert ref() is None 1548*da0073e9SAndroid Build Coastguard Worker gc.enable() 1549*da0073e9SAndroid Build Coastguard Worker 1550*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_state_dict_picklable(self): 1551*da0073e9SAndroid Build Coastguard Worker adam_opt = Adam(self.net.parameters()) 1552*da0073e9SAndroid Build Coastguard Worker 1553*da0073e9SAndroid Build Coastguard Worker # Case 1: Built-in mode 1554*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False) 1555*da0073e9SAndroid Build Coastguard Worker 
self.assertIsInstance(scheduler._scale_fn_ref, types.FunctionType) 1556*da0073e9SAndroid Build Coastguard Worker state = scheduler.state_dict() 1557*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("_scale_fn_ref", state) 1558*da0073e9SAndroid Build Coastguard Worker self.assertIs(state["_scale_fn_custom"], None) 1559*da0073e9SAndroid Build Coastguard Worker pickle.dumps(state) 1560*da0073e9SAndroid Build Coastguard Worker 1561*da0073e9SAndroid Build Coastguard Worker # Case 2: Custom `scale_fn`, a function object 1562*da0073e9SAndroid Build Coastguard Worker def scale_fn(_): 1563*da0073e9SAndroid Build Coastguard Worker return 0.5 1564*da0073e9SAndroid Build Coastguard Worker 1565*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1566*da0073e9SAndroid Build Coastguard Worker adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn 1567*da0073e9SAndroid Build Coastguard Worker ) 1568*da0073e9SAndroid Build Coastguard Worker state = scheduler.state_dict() 1569*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("_scale_fn_ref", state) 1570*da0073e9SAndroid Build Coastguard Worker self.assertIs(state["_scale_fn_custom"], None) 1571*da0073e9SAndroid Build Coastguard Worker pickle.dumps(state) 1572*da0073e9SAndroid Build Coastguard Worker 1573*da0073e9SAndroid Build Coastguard Worker # Case 3: Custom `scale_fn`, a callable class 1574*da0073e9SAndroid Build Coastguard Worker class ScaleFn: 1575*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1576*da0073e9SAndroid Build Coastguard Worker self.x = 0.5 1577*da0073e9SAndroid Build Coastguard Worker 1578*da0073e9SAndroid Build Coastguard Worker def __call__(self, _): 1579*da0073e9SAndroid Build Coastguard Worker return self.x 1580*da0073e9SAndroid Build Coastguard Worker 1581*da0073e9SAndroid Build Coastguard Worker scale_fn = ScaleFn() 1582*da0073e9SAndroid Build Coastguard Worker 1583*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 
1584*da0073e9SAndroid Build Coastguard Worker adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn 1585*da0073e9SAndroid Build Coastguard Worker ) 1586*da0073e9SAndroid Build Coastguard Worker state = scheduler.state_dict() 1587*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("_scale_fn_ref", state) 1588*da0073e9SAndroid Build Coastguard Worker self.assertEqual(state["_scale_fn_custom"], scale_fn.__dict__) 1589*da0073e9SAndroid Build Coastguard Worker pickle.dumps(state) 1590*da0073e9SAndroid Build Coastguard Worker 1591*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_scale_fn_restored_from_state_dict(self): 1592*da0073e9SAndroid Build Coastguard Worker adam_opt = Adam(self.net.parameters()) 1593*da0073e9SAndroid Build Coastguard Worker 1594*da0073e9SAndroid Build Coastguard Worker # Case 1: Built-in mode 1595*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1596*da0073e9SAndroid Build Coastguard Worker adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, mode="triangular2" 1597*da0073e9SAndroid Build Coastguard Worker ) 1598*da0073e9SAndroid Build Coastguard Worker restored_scheduler = CyclicLR( 1599*da0073e9SAndroid Build Coastguard Worker adam_opt, base_lr=1, max_lr=5, cycle_momentum=False 1600*da0073e9SAndroid Build Coastguard Worker ) 1601*da0073e9SAndroid Build Coastguard Worker restored_scheduler.load_state_dict(scheduler.state_dict()) 1602*da0073e9SAndroid Build Coastguard Worker self.assertTrue(restored_scheduler.mode == scheduler.mode == "triangular2") 1603*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(restored_scheduler._scale_fn_ref) and self.assertIsNotNone( 1604*da0073e9SAndroid Build Coastguard Worker scheduler._scale_fn_ref 1605*da0073e9SAndroid Build Coastguard Worker ) 1606*da0073e9SAndroid Build Coastguard Worker self.assertIs(restored_scheduler._scale_fn_custom, None) 1607*da0073e9SAndroid Build Coastguard Worker self.assertIs(scheduler._scale_fn_custom, None) 
1608*da0073e9SAndroid Build Coastguard Worker 1609*da0073e9SAndroid Build Coastguard Worker # Case 2: Custom `scale_fn` 1610*da0073e9SAndroid Build Coastguard Worker def scale_fn(_): 1611*da0073e9SAndroid Build Coastguard Worker return 0.5 1612*da0073e9SAndroid Build Coastguard Worker 1613*da0073e9SAndroid Build Coastguard Worker scheduler = CyclicLR( 1614*da0073e9SAndroid Build Coastguard Worker adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn 1615*da0073e9SAndroid Build Coastguard Worker ) 1616*da0073e9SAndroid Build Coastguard Worker restored_scheduler = CyclicLR( 1617*da0073e9SAndroid Build Coastguard Worker adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn 1618*da0073e9SAndroid Build Coastguard Worker ) 1619*da0073e9SAndroid Build Coastguard Worker restored_scheduler.load_state_dict(scheduler.state_dict()) 1620*da0073e9SAndroid Build Coastguard Worker self.assertIs(scheduler._scale_fn_custom, scale_fn) 1621*da0073e9SAndroid Build Coastguard Worker self.assertIs(restored_scheduler._scale_fn_custom, scale_fn) 1622*da0073e9SAndroid Build Coastguard Worker 1623*da0073e9SAndroid Build Coastguard Worker def test_onecycle_lr_invalid_anneal_strategy(self): 1624*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1625*da0073e9SAndroid Build Coastguard Worker scheduler = OneCycleLR( 1626*da0073e9SAndroid Build Coastguard Worker self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS" 1627*da0073e9SAndroid Build Coastguard Worker ) 1628*da0073e9SAndroid Build Coastguard Worker 1629*da0073e9SAndroid Build Coastguard Worker def test_onecycle_lr_invalid_pct_start(self): 1630*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1631*da0073e9SAndroid Build Coastguard Worker scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1) 1632*da0073e9SAndroid Build Coastguard Worker 1633*da0073e9SAndroid Build Coastguard Worker def 
test_onecycle_lr_cannot_calculate_total_steps(self): 1634*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1635*da0073e9SAndroid Build Coastguard Worker scheduler = OneCycleLR(self.opt, max_lr=1e-3) 1636*da0073e9SAndroid Build Coastguard Worker 1637*da0073e9SAndroid Build Coastguard Worker def test_onecycle_lr_linear_annealing(self): 1638*da0073e9SAndroid Build Coastguard Worker lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] 1639*da0073e9SAndroid Build Coastguard Worker momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] 1640*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1641*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1642*da0073e9SAndroid Build Coastguard Worker scheduler = OneCycleLR( 1643*da0073e9SAndroid Build Coastguard Worker self.opt, 1644*da0073e9SAndroid Build Coastguard Worker max_lr=25, 1645*da0073e9SAndroid Build Coastguard Worker final_div_factor=2, 1646*da0073e9SAndroid Build Coastguard Worker base_momentum=1, 1647*da0073e9SAndroid Build Coastguard Worker max_momentum=22, 1648*da0073e9SAndroid Build Coastguard Worker total_steps=10, 1649*da0073e9SAndroid Build Coastguard Worker anneal_strategy="linear", 1650*da0073e9SAndroid Build Coastguard Worker ) 1651*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker def test_onecycle_lr_linear_annealing_three_phases(self): 1654*da0073e9SAndroid Build Coastguard Worker lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25] 1655*da0073e9SAndroid Build Coastguard Worker momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22] 1656*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1657*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1658*da0073e9SAndroid Build Coastguard 
Worker scheduler = OneCycleLR( 1659*da0073e9SAndroid Build Coastguard Worker self.opt, 1660*da0073e9SAndroid Build Coastguard Worker max_lr=25, 1661*da0073e9SAndroid Build Coastguard Worker div_factor=25, 1662*da0073e9SAndroid Build Coastguard Worker base_momentum=1, 1663*da0073e9SAndroid Build Coastguard Worker max_momentum=22, 1664*da0073e9SAndroid Build Coastguard Worker total_steps=10, 1665*da0073e9SAndroid Build Coastguard Worker anneal_strategy="linear", 1666*da0073e9SAndroid Build Coastguard Worker pct_start=0.4, 1667*da0073e9SAndroid Build Coastguard Worker final_div_factor=4, 1668*da0073e9SAndroid Build Coastguard Worker three_phase=True, 1669*da0073e9SAndroid Build Coastguard Worker ) 1670*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) 1671*da0073e9SAndroid Build Coastguard Worker 1672*da0073e9SAndroid Build Coastguard Worker def test_onecycle_lr_cosine_annealing(self): 1673*da0073e9SAndroid Build Coastguard Worker def annealing_cos(start, end, pct): 1674*da0073e9SAndroid Build Coastguard Worker cos_out = math.cos(math.pi * pct) + 1 1675*da0073e9SAndroid Build Coastguard Worker return end + (start - end) / 2.0 * cos_out 1676*da0073e9SAndroid Build Coastguard Worker 1677*da0073e9SAndroid Build Coastguard Worker lr_target = [ 1678*da0073e9SAndroid Build Coastguard Worker 1, 1679*da0073e9SAndroid Build Coastguard Worker 13, 1680*da0073e9SAndroid Build Coastguard Worker 25, 1681*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 1 / 7.0), 1682*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 2 / 7.0), 1683*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 3 / 7.0), 1684*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 4 / 7.0), 1685*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 5 / 7.0), 1686*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 6 / 7.0), 1687*da0073e9SAndroid Build Coastguard Worker 0.5, 
1688*da0073e9SAndroid Build Coastguard Worker ] 1689*da0073e9SAndroid Build Coastguard Worker momentum_target = [ 1690*da0073e9SAndroid Build Coastguard Worker 22, 1691*da0073e9SAndroid Build Coastguard Worker 11.5, 1692*da0073e9SAndroid Build Coastguard Worker 1, 1693*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 1 / 7.0), 1694*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 2 / 7.0), 1695*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 3 / 7.0), 1696*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 4 / 7.0), 1697*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 5 / 7.0), 1698*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 6 / 7.0), 1699*da0073e9SAndroid Build Coastguard Worker 22, 1700*da0073e9SAndroid Build Coastguard Worker ] 1701*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1702*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1703*da0073e9SAndroid Build Coastguard Worker scheduler = OneCycleLR( 1704*da0073e9SAndroid Build Coastguard Worker self.opt, 1705*da0073e9SAndroid Build Coastguard Worker max_lr=25, 1706*da0073e9SAndroid Build Coastguard Worker final_div_factor=2, 1707*da0073e9SAndroid Build Coastguard Worker base_momentum=1, 1708*da0073e9SAndroid Build Coastguard Worker max_momentum=22, 1709*da0073e9SAndroid Build Coastguard Worker total_steps=10, 1710*da0073e9SAndroid Build Coastguard Worker ) 1711*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) 1712*da0073e9SAndroid Build Coastguard Worker 1713*da0073e9SAndroid Build Coastguard Worker def test_onecycle_lr_legacy_state_dict(self): 1714*da0073e9SAndroid Build Coastguard Worker scheduler = OneCycleLR( 1715*da0073e9SAndroid Build Coastguard Worker self.opt, 1716*da0073e9SAndroid Build Coastguard Worker max_lr=25, 1717*da0073e9SAndroid Build Coastguard Worker final_div_factor=2, 
1718*da0073e9SAndroid Build Coastguard Worker base_momentum=1, 1719*da0073e9SAndroid Build Coastguard Worker max_momentum=22, 1720*da0073e9SAndroid Build Coastguard Worker total_steps=10, 1721*da0073e9SAndroid Build Coastguard Worker anneal_strategy="cos", 1722*da0073e9SAndroid Build Coastguard Worker ) 1723*da0073e9SAndroid Build Coastguard Worker delattr(scheduler, "_anneal_func_type") 1724*da0073e9SAndroid Build Coastguard Worker state_dict = scheduler.state_dict() 1725*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("anneal_func_type", state_dict) 1726*da0073e9SAndroid Build Coastguard Worker state_dict["anneal_func"] = OneCycleLR._annealing_cos 1727*da0073e9SAndroid Build Coastguard Worker scheduler.load_state_dict(state_dict) 1728*da0073e9SAndroid Build Coastguard Worker 1729*da0073e9SAndroid Build Coastguard Worker def annealing_cos(start, end, pct): 1730*da0073e9SAndroid Build Coastguard Worker cos_out = math.cos(math.pi * pct) + 1 1731*da0073e9SAndroid Build Coastguard Worker return end + (start - end) / 2.0 * cos_out 1732*da0073e9SAndroid Build Coastguard Worker 1733*da0073e9SAndroid Build Coastguard Worker lr_target = [ 1734*da0073e9SAndroid Build Coastguard Worker 1, 1735*da0073e9SAndroid Build Coastguard Worker 13, 1736*da0073e9SAndroid Build Coastguard Worker 25, 1737*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 1 / 7.0), 1738*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 2 / 7.0), 1739*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 3 / 7.0), 1740*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 4 / 7.0), 1741*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 5 / 7.0), 1742*da0073e9SAndroid Build Coastguard Worker annealing_cos(25, 0.5, 6 / 7.0), 1743*da0073e9SAndroid Build Coastguard Worker 0.5, 1744*da0073e9SAndroid Build Coastguard Worker ] 1745*da0073e9SAndroid Build Coastguard Worker momentum_target = [ 1746*da0073e9SAndroid Build Coastguard Worker 22, 
1747*da0073e9SAndroid Build Coastguard Worker 11.5, 1748*da0073e9SAndroid Build Coastguard Worker 1, 1749*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 1 / 7.0), 1750*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 2 / 7.0), 1751*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 3 / 7.0), 1752*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 4 / 7.0), 1753*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 5 / 7.0), 1754*da0073e9SAndroid Build Coastguard Worker annealing_cos(1, 22, 6 / 7.0), 1755*da0073e9SAndroid Build Coastguard Worker 22, 1756*da0073e9SAndroid Build Coastguard Worker ] 1757*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1758*da0073e9SAndroid Build Coastguard Worker momentum_targets = [momentum_target, momentum_target] 1759*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) 1760*da0073e9SAndroid Build Coastguard Worker 1761*da0073e9SAndroid Build Coastguard Worker def test_cycle_lr_with_adam(self): 1762*da0073e9SAndroid Build Coastguard Worker old_opt = self.opt 1763*da0073e9SAndroid Build Coastguard Worker self.opt = Adam( 1764*da0073e9SAndroid Build Coastguard Worker [ 1765*da0073e9SAndroid Build Coastguard Worker {"params": self.net.conv1.parameters()}, 1766*da0073e9SAndroid Build Coastguard Worker {"params": self.net.conv2.parameters(), "lr": 0.5}, 1767*da0073e9SAndroid Build Coastguard Worker ], 1768*da0073e9SAndroid Build Coastguard Worker lr=0.05, 1769*da0073e9SAndroid Build Coastguard Worker ) 1770*da0073e9SAndroid Build Coastguard Worker 1771*da0073e9SAndroid Build Coastguard Worker lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] 1772*da0073e9SAndroid Build Coastguard Worker momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] 1773*da0073e9SAndroid Build Coastguard Worker lr_targets = [lr_target, lr_target] 1774*da0073e9SAndroid Build Coastguard Worker momentum_targets = 
[momentum_target, momentum_target] 1775*da0073e9SAndroid Build Coastguard Worker scheduler = OneCycleLR( 1776*da0073e9SAndroid Build Coastguard Worker self.opt, 1777*da0073e9SAndroid Build Coastguard Worker max_lr=25, 1778*da0073e9SAndroid Build Coastguard Worker final_div_factor=2, 1779*da0073e9SAndroid Build Coastguard Worker base_momentum=1, 1780*da0073e9SAndroid Build Coastguard Worker max_momentum=22, 1781*da0073e9SAndroid Build Coastguard Worker total_steps=10, 1782*da0073e9SAndroid Build Coastguard Worker anneal_strategy="linear", 1783*da0073e9SAndroid Build Coastguard Worker ) 1784*da0073e9SAndroid Build Coastguard Worker self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True) 1785*da0073e9SAndroid Build Coastguard Worker self.opt = old_opt # set optimizer back to SGD 1786*da0073e9SAndroid Build Coastguard Worker 1787*da0073e9SAndroid Build Coastguard Worker def test_lambda_lr(self): 1788*da0073e9SAndroid Build Coastguard Worker epochs = 10 1789*da0073e9SAndroid Build Coastguard Worker self.opt.param_groups[0]["lr"] = 0.05 1790*da0073e9SAndroid Build Coastguard Worker self.opt.param_groups[1]["lr"] = 0.4 1791*da0073e9SAndroid Build Coastguard Worker targets = [ 1792*da0073e9SAndroid Build Coastguard Worker [0.05 * (0.9**x) for x in range(epochs)], 1793*da0073e9SAndroid Build Coastguard Worker [0.4 * (0.8**x) for x in range(epochs)], 1794*da0073e9SAndroid Build Coastguard Worker ] 1795*da0073e9SAndroid Build Coastguard Worker scheduler = LambdaLR( 1796*da0073e9SAndroid Build Coastguard Worker self.opt, lr_lambda=[lambda x1: 0.9**x1, lambda x2: 0.8**x2] 1797*da0073e9SAndroid Build Coastguard Worker ) 1798*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 1799*da0073e9SAndroid Build Coastguard Worker 1800*da0073e9SAndroid Build Coastguard Worker def test_multiplicative_lr(self): 1801*da0073e9SAndroid Build Coastguard Worker epochs = 10 1802*da0073e9SAndroid Build Coastguard Worker 
self.opt.param_groups[0]["lr"] = 0.05 1803*da0073e9SAndroid Build Coastguard Worker self.opt.param_groups[1]["lr"] = 0.4 1804*da0073e9SAndroid Build Coastguard Worker targets = [ 1805*da0073e9SAndroid Build Coastguard Worker [0.05 * (0.9**x) for x in range(epochs)], 1806*da0073e9SAndroid Build Coastguard Worker [0.4 * (0.8**x) for x in range(epochs)], 1807*da0073e9SAndroid Build Coastguard Worker ] 1808*da0073e9SAndroid Build Coastguard Worker scheduler = MultiplicativeLR( 1809*da0073e9SAndroid Build Coastguard Worker self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8] 1810*da0073e9SAndroid Build Coastguard Worker ) 1811*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, epochs) 1812*da0073e9SAndroid Build Coastguard Worker 1813*da0073e9SAndroid Build Coastguard Worker @parametrize("T_mult", [1, 2, 4]) 1814*da0073e9SAndroid Build Coastguard Worker def test_CosineAnnealingWarmRestarts_lr1(self, T_mult): 1815*da0073e9SAndroid Build Coastguard Worker iters = 100 1816*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 1817*da0073e9SAndroid Build Coastguard Worker T_i = 10 1818*da0073e9SAndroid Build Coastguard Worker T_cur = 0 1819*da0073e9SAndroid Build Coastguard Worker targets = [[0.05], [0.5]] 1820*da0073e9SAndroid Build Coastguard Worker scheduler = CosineAnnealingWarmRestarts( 1821*da0073e9SAndroid Build Coastguard Worker self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min 1822*da0073e9SAndroid Build Coastguard Worker ) 1823*da0073e9SAndroid Build Coastguard Worker for _ in range(1, iters, 1): 1824*da0073e9SAndroid Build Coastguard Worker T_cur += 1 1825*da0073e9SAndroid Build Coastguard Worker if T_cur >= T_i: 1826*da0073e9SAndroid Build Coastguard Worker T_cur = T_cur - T_i 1827*da0073e9SAndroid Build Coastguard Worker T_i = int(T_mult) * T_i 1828*da0073e9SAndroid Build Coastguard Worker targets[0] += [ 1829*da0073e9SAndroid Build Coastguard Worker eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 
1830*da0073e9SAndroid Build Coastguard Worker ] 1831*da0073e9SAndroid Build Coastguard Worker targets[1] += [ 1832*da0073e9SAndroid Build Coastguard Worker eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1833*da0073e9SAndroid Build Coastguard Worker ] 1834*da0073e9SAndroid Build Coastguard Worker self._test(scheduler, targets, iters) 1835*da0073e9SAndroid Build Coastguard Worker 1836*da0073e9SAndroid Build Coastguard Worker def test_CosineAnnealingWarmRestarts_lr2(self): 1837*da0073e9SAndroid Build Coastguard Worker iters = 30 1838*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 1839*da0073e9SAndroid Build Coastguard Worker T_mults = [1, 2, 4] 1840*da0073e9SAndroid Build Coastguard Worker for T_mult in T_mults: 1841*da0073e9SAndroid Build Coastguard Worker T_i = 10 1842*da0073e9SAndroid Build Coastguard Worker T_cur = 0 1843*da0073e9SAndroid Build Coastguard Worker targets = [[0.05], [0.5]] 1844*da0073e9SAndroid Build Coastguard Worker scheduler = CosineAnnealingWarmRestarts( 1845*da0073e9SAndroid Build Coastguard Worker self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min 1846*da0073e9SAndroid Build Coastguard Worker ) 1847*da0073e9SAndroid Build Coastguard Worker for _ in torch.arange(0.1, iters, 0.1): 1848*da0073e9SAndroid Build Coastguard Worker T_cur = round(T_cur + 0.1, 1) 1849*da0073e9SAndroid Build Coastguard Worker if T_cur >= T_i: 1850*da0073e9SAndroid Build Coastguard Worker T_cur = T_cur - T_i 1851*da0073e9SAndroid Build Coastguard Worker T_i = int(T_mult) * T_i 1852*da0073e9SAndroid Build Coastguard Worker targets[0] += [ 1853*da0073e9SAndroid Build Coastguard Worker eta_min 1854*da0073e9SAndroid Build Coastguard Worker + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1855*da0073e9SAndroid Build Coastguard Worker ] 1856*da0073e9SAndroid Build Coastguard Worker targets[1] += [ 1857*da0073e9SAndroid Build Coastguard Worker eta_min 1858*da0073e9SAndroid Build Coastguard Worker + (0.5 - eta_min) * (1 + 
math.cos(math.pi * T_cur / T_i)) / 2 1859*da0073e9SAndroid Build Coastguard Worker ] 1860*da0073e9SAndroid Build Coastguard Worker self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters) 1861*da0073e9SAndroid Build Coastguard Worker 1862*da0073e9SAndroid Build Coastguard Worker def test_CosineAnnealingWarmRestarts_lr3(self): 1863*da0073e9SAndroid Build Coastguard Worker epochs_for_T_mults = [ 1864*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13], 1865*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3], 1866*da0073e9SAndroid Build Coastguard Worker [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50], 1867*da0073e9SAndroid Build Coastguard Worker ] 1868*da0073e9SAndroid Build Coastguard Worker T_curs_for_T_mults = [ 1869*da0073e9SAndroid Build Coastguard Worker [1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3], 1870*da0073e9SAndroid Build Coastguard Worker [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3], 1871*da0073e9SAndroid Build Coastguard Worker [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10], 1872*da0073e9SAndroid Build Coastguard Worker ] 1873*da0073e9SAndroid Build Coastguard Worker T_is_for_T_mults = [ 1874*da0073e9SAndroid Build Coastguard Worker [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10], 1875*da0073e9SAndroid Build Coastguard Worker [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10], 1876*da0073e9SAndroid Build Coastguard Worker [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90], 1877*da0073e9SAndroid Build Coastguard Worker ] 1878*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 1879*da0073e9SAndroid Build Coastguard Worker T_mults = [1, 2, 3] 1880*da0073e9SAndroid Build Coastguard Worker for epochs, T_mult, T_curs, T_is in zip( 1881*da0073e9SAndroid Build Coastguard Worker epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults 1882*da0073e9SAndroid Build Coastguard Worker ): 1883*da0073e9SAndroid Build Coastguard Worker targets = [[0.05], 
[0.5]] 1884*da0073e9SAndroid Build Coastguard Worker scheduler = CosineAnnealingWarmRestarts( 1885*da0073e9SAndroid Build Coastguard Worker self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min 1886*da0073e9SAndroid Build Coastguard Worker ) 1887*da0073e9SAndroid Build Coastguard Worker for T_cur, T_i in zip(T_curs, T_is): 1888*da0073e9SAndroid Build Coastguard Worker targets[0] += [ 1889*da0073e9SAndroid Build Coastguard Worker eta_min 1890*da0073e9SAndroid Build Coastguard Worker + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1891*da0073e9SAndroid Build Coastguard Worker ] 1892*da0073e9SAndroid Build Coastguard Worker targets[1] += [ 1893*da0073e9SAndroid Build Coastguard Worker eta_min 1894*da0073e9SAndroid Build Coastguard Worker + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1895*da0073e9SAndroid Build Coastguard Worker ] 1896*da0073e9SAndroid Build Coastguard Worker self._test_interleaved_CosineAnnealingWarmRestarts( 1897*da0073e9SAndroid Build Coastguard Worker scheduler, targets, epochs 1898*da0073e9SAndroid Build Coastguard Worker ) 1899*da0073e9SAndroid Build Coastguard Worker 1900*da0073e9SAndroid Build Coastguard Worker def test_swalr_no_anneal(self): 1901*da0073e9SAndroid Build Coastguard Worker epochs, swa_start, swa_lr = 10, 5, 0.01 1902*da0073e9SAndroid Build Coastguard Worker initial_lrs = [group["lr"] for group in self.opt.param_groups] 1903*da0073e9SAndroid Build Coastguard Worker targets = [ 1904*da0073e9SAndroid Build Coastguard Worker [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1) 1905*da0073e9SAndroid Build Coastguard Worker for lr in initial_lrs 1906*da0073e9SAndroid Build Coastguard Worker ] 1907*da0073e9SAndroid Build Coastguard Worker swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr) 1908*da0073e9SAndroid Build Coastguard Worker self._test_swalr(swa_scheduler, None, targets, swa_start, epochs) 1909*da0073e9SAndroid Build Coastguard Worker 1910*da0073e9SAndroid Build Coastguard Worker 
def test_swalr_cosine_anneal_after_multiplicative(self): 1911*da0073e9SAndroid Build Coastguard Worker # same swa_lr for different param_groups 1912*da0073e9SAndroid Build Coastguard Worker epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5 1913*da0073e9SAndroid Build Coastguard Worker mult_factor = 0.9 1914*da0073e9SAndroid Build Coastguard Worker scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) 1915*da0073e9SAndroid Build Coastguard Worker swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr) 1916*da0073e9SAndroid Build Coastguard Worker 1917*da0073e9SAndroid Build Coastguard Worker def anneal_coef(t): 1918*da0073e9SAndroid Build Coastguard Worker if t + 1 >= anneal_epochs: 1919*da0073e9SAndroid Build Coastguard Worker return 0.0 1920*da0073e9SAndroid Build Coastguard Worker return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2 1921*da0073e9SAndroid Build Coastguard Worker 1922*da0073e9SAndroid Build Coastguard Worker initial_lrs = [group["lr"] for group in self.opt.param_groups] 1923*da0073e9SAndroid Build Coastguard Worker targets_before_swa = [ 1924*da0073e9SAndroid Build Coastguard Worker [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs 1925*da0073e9SAndroid Build Coastguard Worker ] 1926*da0073e9SAndroid Build Coastguard Worker swa_epochs = epochs - swa_start - 1 1927*da0073e9SAndroid Build Coastguard Worker targets = [ 1928*da0073e9SAndroid Build Coastguard Worker lrs 1929*da0073e9SAndroid Build Coastguard Worker + [ 1930*da0073e9SAndroid Build Coastguard Worker lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) 1931*da0073e9SAndroid Build Coastguard Worker for t in range(swa_epochs) 1932*da0073e9SAndroid Build Coastguard Worker ] 1933*da0073e9SAndroid Build Coastguard Worker for lrs in targets_before_swa 1934*da0073e9SAndroid Build Coastguard Worker ] 1935*da0073e9SAndroid Build Coastguard Worker 1936*da0073e9SAndroid Build Coastguard Worker 
self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) 1937*da0073e9SAndroid Build Coastguard Worker 1938*da0073e9SAndroid Build Coastguard Worker def test_swalr_linear_anneal_after_multiplicative(self): 1939*da0073e9SAndroid Build Coastguard Worker # separate swa_lr for different param_groups 1940*da0073e9SAndroid Build Coastguard Worker epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4 1941*da0073e9SAndroid Build Coastguard Worker mult_factor = 0.9 1942*da0073e9SAndroid Build Coastguard Worker scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) 1943*da0073e9SAndroid Build Coastguard Worker swa_scheduler = SWALR( 1944*da0073e9SAndroid Build Coastguard Worker self.opt, 1945*da0073e9SAndroid Build Coastguard Worker anneal_epochs=anneal_epochs, 1946*da0073e9SAndroid Build Coastguard Worker anneal_strategy="linear", 1947*da0073e9SAndroid Build Coastguard Worker swa_lr=swa_lrs, 1948*da0073e9SAndroid Build Coastguard Worker ) 1949*da0073e9SAndroid Build Coastguard Worker 1950*da0073e9SAndroid Build Coastguard Worker def anneal_coef(t): 1951*da0073e9SAndroid Build Coastguard Worker if t + 1 >= anneal_epochs: 1952*da0073e9SAndroid Build Coastguard Worker return 0.0 1953*da0073e9SAndroid Build Coastguard Worker return 1 - (t + 1) / anneal_epochs 1954*da0073e9SAndroid Build Coastguard Worker 1955*da0073e9SAndroid Build Coastguard Worker initial_lrs = [group["lr"] for group in self.opt.param_groups] 1956*da0073e9SAndroid Build Coastguard Worker targets_before_swa = [ 1957*da0073e9SAndroid Build Coastguard Worker [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs 1958*da0073e9SAndroid Build Coastguard Worker ] 1959*da0073e9SAndroid Build Coastguard Worker swa_epochs = epochs - swa_start - 1 1960*da0073e9SAndroid Build Coastguard Worker targets = [ 1961*da0073e9SAndroid Build Coastguard Worker lrs 1962*da0073e9SAndroid Build Coastguard Worker + [ 1963*da0073e9SAndroid Build Coastguard Worker 
lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) 1964*da0073e9SAndroid Build Coastguard Worker for t in range(swa_epochs) 1965*da0073e9SAndroid Build Coastguard Worker ] 1966*da0073e9SAndroid Build Coastguard Worker for lrs, swa_lr in zip(targets_before_swa, swa_lrs) 1967*da0073e9SAndroid Build Coastguard Worker ] 1968*da0073e9SAndroid Build Coastguard Worker 1969*da0073e9SAndroid Build Coastguard Worker self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) 1970*da0073e9SAndroid Build Coastguard Worker 1971*da0073e9SAndroid Build Coastguard Worker def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs): 1972*da0073e9SAndroid Build Coastguard Worker for epoch in range(epochs): 1973*da0073e9SAndroid Build Coastguard Worker for param_group, target in zip(self.opt.param_groups, targets): 1974*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1975*da0073e9SAndroid Build Coastguard Worker target[epoch], 1976*da0073e9SAndroid Build Coastguard Worker param_group["lr"], 1977*da0073e9SAndroid Build Coastguard Worker msg="LR is wrong in epoch {}: expected {}, got {}".format( 1978*da0073e9SAndroid Build Coastguard Worker epoch, target[epoch], param_group["lr"] 1979*da0073e9SAndroid Build Coastguard Worker ), 1980*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 1981*da0073e9SAndroid Build Coastguard Worker rtol=0, 1982*da0073e9SAndroid Build Coastguard Worker ) 1983*da0073e9SAndroid Build Coastguard Worker if epoch >= swa_start: 1984*da0073e9SAndroid Build Coastguard Worker self.opt.step() 1985*da0073e9SAndroid Build Coastguard Worker swa_scheduler.step() 1986*da0073e9SAndroid Build Coastguard Worker elif scheduler is not None: 1987*da0073e9SAndroid Build Coastguard Worker self.opt.step() 1988*da0073e9SAndroid Build Coastguard Worker scheduler.step() 1989*da0073e9SAndroid Build Coastguard Worker 1990*da0073e9SAndroid Build Coastguard Worker def test_swalr_hypers(self): 1991*da0073e9SAndroid Build Coastguard Worker # Test 
that SWALR raises errors for incorrect hyper-parameters 1992*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "anneal_strategy must"): 1993*da0073e9SAndroid Build Coastguard Worker swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0) 1994*da0073e9SAndroid Build Coastguard Worker 1995*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "anneal_epochs must"): 1996*da0073e9SAndroid Build Coastguard Worker swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0) 1997*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "anneal_epochs must"): 1998*da0073e9SAndroid Build Coastguard Worker swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0) 1999*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "swa_lr must"): 2000*da0073e9SAndroid Build Coastguard Worker swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01]) 2001*da0073e9SAndroid Build Coastguard Worker 2002*da0073e9SAndroid Build Coastguard Worker def test_step_lr_state_dict(self): 2003*da0073e9SAndroid Build Coastguard Worker self._check_scheduler_state_dict( 2004*da0073e9SAndroid Build Coastguard Worker lambda: StepLR(self.opt, gamma=0.1, step_size=3), 2005*da0073e9SAndroid Build Coastguard Worker lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1), 2006*da0073e9SAndroid Build Coastguard Worker ) 2007*da0073e9SAndroid Build Coastguard Worker 2008*da0073e9SAndroid Build Coastguard Worker def test_multi_step_lr_state_dict(self): 2009*da0073e9SAndroid Build Coastguard Worker self._check_scheduler_state_dict( 2010*da0073e9SAndroid Build Coastguard Worker lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]), 2011*da0073e9SAndroid Build Coastguard Worker lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]), 2012*da0073e9SAndroid Build Coastguard Worker ) 2013*da0073e9SAndroid Build Coastguard Worker 2014*da0073e9SAndroid Build Coastguard Worker 
def test_exp_step_lr_state_dict(self): 2015*da0073e9SAndroid Build Coastguard Worker self._check_scheduler_state_dict( 2016*da0073e9SAndroid Build Coastguard Worker lambda: ExponentialLR(self.opt, gamma=0.1), 2017*da0073e9SAndroid Build Coastguard Worker lambda: ExponentialLR(self.opt, gamma=0.01), 2018*da0073e9SAndroid Build Coastguard Worker ) 2019*da0073e9SAndroid Build Coastguard Worker 2020*da0073e9SAndroid Build Coastguard Worker def test_cosine_lr_state_dict(self): 2021*da0073e9SAndroid Build Coastguard Worker epochs = 10 2022*da0073e9SAndroid Build Coastguard Worker eta_min = 1e-10 2023*da0073e9SAndroid Build Coastguard Worker self._check_scheduler_state_dict( 2024*da0073e9SAndroid Build Coastguard Worker lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min), 2025*da0073e9SAndroid Build Coastguard Worker lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2), 2026*da0073e9SAndroid Build Coastguard Worker epochs=epochs, 2027*da0073e9SAndroid Build Coastguard Worker ) 2028*da0073e9SAndroid Build Coastguard Worker 2029*da0073e9SAndroid Build Coastguard Worker def test_reduce_lr_on_plateau_state_dict(self): 2030*da0073e9SAndroid Build Coastguard Worker scheduler = ReduceLROnPlateau(self.opt, mode="min", factor=0.1, patience=2) 2031*da0073e9SAndroid Build Coastguard Worker for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]: 2032*da0073e9SAndroid Build Coastguard Worker scheduler.step(score) 2033*da0073e9SAndroid Build Coastguard Worker scheduler_copy = ReduceLROnPlateau( 2034*da0073e9SAndroid Build Coastguard Worker self.opt, mode="max", factor=0.5, patience=10 2035*da0073e9SAndroid Build Coastguard Worker ) 2036*da0073e9SAndroid Build Coastguard Worker scheduler_copy.load_state_dict(scheduler.state_dict()) 2037*da0073e9SAndroid Build Coastguard Worker for key in scheduler.__dict__.keys(): 2038*da0073e9SAndroid Build Coastguard Worker if key not in {"optimizer", "is_better"}: 2039*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) 2040*da0073e9SAndroid Build Coastguard Worker 2041*da0073e9SAndroid Build Coastguard Worker def test_lambda_lr_state_dict_fn(self): 2042*da0073e9SAndroid Build Coastguard Worker scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x) 2043*da0073e9SAndroid Build Coastguard Worker state = scheduler.state_dict() 2044*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(state["lr_lambdas"][0]) 2045*da0073e9SAndroid Build Coastguard Worker 2046*da0073e9SAndroid Build Coastguard Worker scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x) 2047*da0073e9SAndroid Build Coastguard Worker scheduler_copy.load_state_dict(state) 2048*da0073e9SAndroid Build Coastguard Worker for key in scheduler.__dict__.keys(): 2049*da0073e9SAndroid Build Coastguard Worker if key not in {"optimizer", "lr_lambdas"}: 2050*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) 2051*da0073e9SAndroid Build Coastguard Worker 2052*da0073e9SAndroid Build Coastguard Worker def test_lambda_lr_state_dict_obj(self): 2053*da0073e9SAndroid Build Coastguard Worker scheduler = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(10)) 2054*da0073e9SAndroid Build Coastguard Worker state = scheduler.state_dict() 2055*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(state["lr_lambdas"][0]) 2056*da0073e9SAndroid Build Coastguard Worker 2057*da0073e9SAndroid Build Coastguard Worker scheduler_copy = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(-1)) 2058*da0073e9SAndroid Build Coastguard Worker scheduler_copy.load_state_dict(state) 2059*da0073e9SAndroid Build Coastguard Worker for key in scheduler.__dict__.keys(): 2060*da0073e9SAndroid Build Coastguard Worker if key not in {"optimizer"}: 2061*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) 2062*da0073e9SAndroid Build Coastguard Worker 
2063*da0073e9SAndroid Build Coastguard Worker def test_CosineAnnealingWarmRestarts_lr_state_dict(self): 2064*da0073e9SAndroid Build Coastguard Worker self._check_scheduler_state_dict( 2065*da0073e9SAndroid Build Coastguard Worker lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2), 2066*da0073e9SAndroid Build Coastguard Worker lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100), 2067*da0073e9SAndroid Build Coastguard Worker ) 2068*da0073e9SAndroid Build Coastguard Worker 2069*da0073e9SAndroid Build Coastguard Worker def test_swa_lr_state_dict(self): 2070*da0073e9SAndroid Build Coastguard Worker self._check_scheduler_state_dict( 2071*da0073e9SAndroid Build Coastguard Worker lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5), 2072*da0073e9SAndroid Build Coastguard Worker lambda: SWALR( 2073*da0073e9SAndroid Build Coastguard Worker self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0 2074*da0073e9SAndroid Build Coastguard Worker ), 2075*da0073e9SAndroid Build Coastguard Worker ) 2076*da0073e9SAndroid Build Coastguard Worker 2077*da0073e9SAndroid Build Coastguard Worker def _check_scheduler_state_dict(self, constr, constr2, epochs=10): 2078*da0073e9SAndroid Build Coastguard Worker scheduler = constr() 2079*da0073e9SAndroid Build Coastguard Worker for _ in range(epochs): 2080*da0073e9SAndroid Build Coastguard Worker scheduler.optimizer.step() 2081*da0073e9SAndroid Build Coastguard Worker scheduler.step() 2082*da0073e9SAndroid Build Coastguard Worker scheduler_copy = constr2() 2083*da0073e9SAndroid Build Coastguard Worker scheduler_copy.load_state_dict(scheduler.state_dict()) 2084*da0073e9SAndroid Build Coastguard Worker for key in scheduler.__dict__.keys(): 2085*da0073e9SAndroid Build Coastguard Worker if key != "optimizer": 2086*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) 2087*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler.get_last_lr(), 
scheduler_copy.get_last_lr()) 2088*da0073e9SAndroid Build Coastguard Worker 2089*da0073e9SAndroid Build Coastguard Worker def _test_get_last_lr(self, schedulers, targets, epochs=10): 2090*da0073e9SAndroid Build Coastguard Worker if isinstance(schedulers, LRScheduler): 2091*da0073e9SAndroid Build Coastguard Worker schedulers = [schedulers] 2092*da0073e9SAndroid Build Coastguard Worker optimizers = {scheduler.optimizer for scheduler in schedulers} 2093*da0073e9SAndroid Build Coastguard Worker for epoch in range(epochs): 2094*da0073e9SAndroid Build Coastguard Worker result = [scheduler.get_last_lr() for scheduler in schedulers] 2095*da0073e9SAndroid Build Coastguard Worker [optimizer.step() for optimizer in optimizers] 2096*da0073e9SAndroid Build Coastguard Worker [scheduler.step() for scheduler in schedulers] 2097*da0073e9SAndroid Build Coastguard Worker target = [[t[epoch] for t in targets]] * len(schedulers) 2098*da0073e9SAndroid Build Coastguard Worker for t, r in zip(target, result): 2099*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2100*da0073e9SAndroid Build Coastguard Worker t, 2101*da0073e9SAndroid Build Coastguard Worker r, 2102*da0073e9SAndroid Build Coastguard Worker msg=f"LR is wrong in epoch {epoch}: expected {t}, got {r}", 2103*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2104*da0073e9SAndroid Build Coastguard Worker rtol=0, 2105*da0073e9SAndroid Build Coastguard Worker ) 2106*da0073e9SAndroid Build Coastguard Worker 2107*da0073e9SAndroid Build Coastguard Worker def _test_with_epoch(self, schedulers, targets, epochs=10): 2108*da0073e9SAndroid Build Coastguard Worker if isinstance(schedulers, LRScheduler): 2109*da0073e9SAndroid Build Coastguard Worker schedulers = [schedulers] 2110*da0073e9SAndroid Build Coastguard Worker optimizers = {scheduler.optimizer for scheduler in schedulers} 2111*da0073e9SAndroid Build Coastguard Worker for epoch in range(epochs): 2112*da0073e9SAndroid Build Coastguard Worker [optimizer.step() for optimizer 
in optimizers] 2113*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 2114*da0073e9SAndroid Build Coastguard Worker [ 2115*da0073e9SAndroid Build Coastguard Worker scheduler.step(epoch) for scheduler in schedulers 2116*da0073e9SAndroid Build Coastguard Worker ] # step before assert: skip initial lr 2117*da0073e9SAndroid Build Coastguard Worker self._check_warning_is_epoch_deprecation_warning( 2118*da0073e9SAndroid Build Coastguard Worker w, num_warnings=len(schedulers) 2119*da0073e9SAndroid Build Coastguard Worker ) 2120*da0073e9SAndroid Build Coastguard Worker for param_group, target in zip(self.opt.param_groups, targets): 2121*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2122*da0073e9SAndroid Build Coastguard Worker target[epoch], 2123*da0073e9SAndroid Build Coastguard Worker param_group["lr"], 2124*da0073e9SAndroid Build Coastguard Worker msg="LR is wrong in epoch {}: expected {}, got {}".format( 2125*da0073e9SAndroid Build Coastguard Worker epoch, target[epoch], param_group["lr"] 2126*da0073e9SAndroid Build Coastguard Worker ), 2127*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2128*da0073e9SAndroid Build Coastguard Worker rtol=0, 2129*da0073e9SAndroid Build Coastguard Worker ) 2130*da0073e9SAndroid Build Coastguard Worker 2131*da0073e9SAndroid Build Coastguard Worker def _test(self, schedulers, targets, epochs=10): 2132*da0073e9SAndroid Build Coastguard Worker if isinstance(schedulers, LRScheduler): 2133*da0073e9SAndroid Build Coastguard Worker schedulers = [schedulers] 2134*da0073e9SAndroid Build Coastguard Worker for epoch in range(epochs): 2135*da0073e9SAndroid Build Coastguard Worker for param_group, target in zip(self.opt.param_groups, targets): 2136*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2137*da0073e9SAndroid Build Coastguard Worker target[epoch], 2138*da0073e9SAndroid Build Coastguard Worker param_group["lr"], 2139*da0073e9SAndroid Build Coastguard Worker msg="LR is wrong in 
epoch {}: expected {}, got {}".format( 2140*da0073e9SAndroid Build Coastguard Worker epoch, target[epoch], param_group["lr"] 2141*da0073e9SAndroid Build Coastguard Worker ), 2142*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2143*da0073e9SAndroid Build Coastguard Worker rtol=0, 2144*da0073e9SAndroid Build Coastguard Worker ) 2145*da0073e9SAndroid Build Coastguard Worker [scheduler.step() for scheduler in schedulers] 2146*da0073e9SAndroid Build Coastguard Worker 2147*da0073e9SAndroid Build Coastguard Worker def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10): 2148*da0073e9SAndroid Build Coastguard Worker for index, epoch in enumerate(torch.arange(0, epochs, 0.1)): 2149*da0073e9SAndroid Build Coastguard Worker epoch = round(epoch.item(), 1) 2150*da0073e9SAndroid Build Coastguard Worker scheduler.step(epoch) 2151*da0073e9SAndroid Build Coastguard Worker for param_group, target in zip(self.opt.param_groups, targets): 2152*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2153*da0073e9SAndroid Build Coastguard Worker target[index], 2154*da0073e9SAndroid Build Coastguard Worker param_group["lr"], 2155*da0073e9SAndroid Build Coastguard Worker msg="LR is wrong in epoch {}: expected {}, got {}".format( 2156*da0073e9SAndroid Build Coastguard Worker epoch, target[index], param_group["lr"] 2157*da0073e9SAndroid Build Coastguard Worker ), 2158*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2159*da0073e9SAndroid Build Coastguard Worker rtol=0, 2160*da0073e9SAndroid Build Coastguard Worker ) 2161*da0073e9SAndroid Build Coastguard Worker 2162*da0073e9SAndroid Build Coastguard Worker def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs): 2163*da0073e9SAndroid Build Coastguard Worker for index, epoch in enumerate(epochs): 2164*da0073e9SAndroid Build Coastguard Worker scheduler.step(epoch) 2165*da0073e9SAndroid Build Coastguard Worker for param_group, target in zip(self.opt.param_groups, targets): 
2166*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2167*da0073e9SAndroid Build Coastguard Worker target[index], 2168*da0073e9SAndroid Build Coastguard Worker param_group["lr"], 2169*da0073e9SAndroid Build Coastguard Worker msg="LR is wrong in epoch {}: expected {}, got {}".format( 2170*da0073e9SAndroid Build Coastguard Worker epoch, target[index], param_group["lr"] 2171*da0073e9SAndroid Build Coastguard Worker ), 2172*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2173*da0073e9SAndroid Build Coastguard Worker rtol=0, 2174*da0073e9SAndroid Build Coastguard Worker ) 2175*da0073e9SAndroid Build Coastguard Worker 2176*da0073e9SAndroid Build Coastguard Worker def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10): 2177*da0073e9SAndroid Build Coastguard Worker self.setUp() 2178*da0073e9SAndroid Build Coastguard Worker targets = [] 2179*da0073e9SAndroid Build Coastguard Worker for epoch in range(epochs): 2180*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler.optimizer.step() 2181*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 2182*da0073e9SAndroid Build Coastguard Worker closed_form_scheduler.step(epoch) 2183*da0073e9SAndroid Build Coastguard Worker self._check_warning_is_epoch_deprecation_warning(w) 2184*da0073e9SAndroid Build Coastguard Worker targets.append([group["lr"] for group in self.opt.param_groups]) 2185*da0073e9SAndroid Build Coastguard Worker self.setUp() 2186*da0073e9SAndroid Build Coastguard Worker for epoch in range(epochs): 2187*da0073e9SAndroid Build Coastguard Worker self.opt.step() 2188*da0073e9SAndroid Build Coastguard Worker scheduler.step() 2189*da0073e9SAndroid Build Coastguard Worker for i, param_group in enumerate(self.opt.param_groups): 2190*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2191*da0073e9SAndroid Build Coastguard Worker targets[epoch][i], 2192*da0073e9SAndroid Build Coastguard Worker param_group["lr"], 2193*da0073e9SAndroid 
Build Coastguard Worker msg="LR is wrong in epoch {}: expected {}, got {}".format( 2194*da0073e9SAndroid Build Coastguard Worker epoch, targets[epoch][i], param_group["lr"] 2195*da0073e9SAndroid Build Coastguard Worker ), 2196*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2197*da0073e9SAndroid Build Coastguard Worker rtol=0, 2198*da0073e9SAndroid Build Coastguard Worker ) 2199*da0073e9SAndroid Build Coastguard Worker 2200*da0073e9SAndroid Build Coastguard Worker def _test_reduce_lr_on_plateau( 2201*da0073e9SAndroid Build Coastguard Worker self, schedulers, targets, metrics, epochs=10, verbose=False 2202*da0073e9SAndroid Build Coastguard Worker ): 2203*da0073e9SAndroid Build Coastguard Worker if isinstance(schedulers, (LRScheduler, ReduceLROnPlateau)): 2204*da0073e9SAndroid Build Coastguard Worker schedulers = [schedulers] 2205*da0073e9SAndroid Build Coastguard Worker for epoch in range(epochs): 2206*da0073e9SAndroid Build Coastguard Worker self.opt.step() 2207*da0073e9SAndroid Build Coastguard Worker for scheduler in schedulers: 2208*da0073e9SAndroid Build Coastguard Worker if isinstance(scheduler, ReduceLROnPlateau): 2209*da0073e9SAndroid Build Coastguard Worker scheduler.step(metrics[epoch]) 2210*da0073e9SAndroid Build Coastguard Worker else: 2211*da0073e9SAndroid Build Coastguard Worker scheduler.step() 2212*da0073e9SAndroid Build Coastguard Worker if verbose: 2213*da0073e9SAndroid Build Coastguard Worker print("epoch{}:\tlr={}".format(epoch, self.opt.param_groups[0]["lr"])) 2214*da0073e9SAndroid Build Coastguard Worker for param_group, target in zip(self.opt.param_groups, targets): 2215*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2216*da0073e9SAndroid Build Coastguard Worker target[epoch], 2217*da0073e9SAndroid Build Coastguard Worker param_group["lr"], 2218*da0073e9SAndroid Build Coastguard Worker msg="LR is wrong in epoch {}: expected {}, got {}".format( 2219*da0073e9SAndroid Build Coastguard Worker epoch, target[epoch], param_group["lr"] 
2220*da0073e9SAndroid Build Coastguard Worker ), 2221*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2222*da0073e9SAndroid Build Coastguard Worker rtol=0, 2223*da0073e9SAndroid Build Coastguard Worker ) 2224*da0073e9SAndroid Build Coastguard Worker 2225*da0073e9SAndroid Build Coastguard Worker def _test_cycle_lr( 2226*da0073e9SAndroid Build Coastguard Worker self, 2227*da0073e9SAndroid Build Coastguard Worker scheduler, 2228*da0073e9SAndroid Build Coastguard Worker lr_targets, 2229*da0073e9SAndroid Build Coastguard Worker momentum_targets, 2230*da0073e9SAndroid Build Coastguard Worker batch_iterations, 2231*da0073e9SAndroid Build Coastguard Worker verbose=False, 2232*da0073e9SAndroid Build Coastguard Worker use_beta1=False, 2233*da0073e9SAndroid Build Coastguard Worker ): 2234*da0073e9SAndroid Build Coastguard Worker for batch_num in range(batch_iterations): 2235*da0073e9SAndroid Build Coastguard Worker if verbose: 2236*da0073e9SAndroid Build Coastguard Worker if "momentum" in self.opt.param_groups[0].keys(): 2237*da0073e9SAndroid Build Coastguard Worker print( 2238*da0073e9SAndroid Build Coastguard Worker "batch{}:\tlr={},momentum={}".format( 2239*da0073e9SAndroid Build Coastguard Worker batch_num, 2240*da0073e9SAndroid Build Coastguard Worker self.opt.param_groups[0]["lr"], 2241*da0073e9SAndroid Build Coastguard Worker self.opt.param_groups[0]["momentum"], 2242*da0073e9SAndroid Build Coastguard Worker ) 2243*da0073e9SAndroid Build Coastguard Worker ) 2244*da0073e9SAndroid Build Coastguard Worker elif use_beta1 and "betas" in self.opt.param_groups[0].keys(): 2245*da0073e9SAndroid Build Coastguard Worker print( 2246*da0073e9SAndroid Build Coastguard Worker "batch{}:\tlr={},beta1={}".format( 2247*da0073e9SAndroid Build Coastguard Worker batch_num, 2248*da0073e9SAndroid Build Coastguard Worker self.opt.param_groups[0]["lr"], 2249*da0073e9SAndroid Build Coastguard Worker self.opt.param_groups[0]["betas"][0], 2250*da0073e9SAndroid Build Coastguard Worker ) 
2251*da0073e9SAndroid Build Coastguard Worker ) 2252*da0073e9SAndroid Build Coastguard Worker else: 2253*da0073e9SAndroid Build Coastguard Worker print( 2254*da0073e9SAndroid Build Coastguard Worker "batch{}:\tlr={}".format( 2255*da0073e9SAndroid Build Coastguard Worker batch_num, self.opt.param_groups[0]["lr"] 2256*da0073e9SAndroid Build Coastguard Worker ) 2257*da0073e9SAndroid Build Coastguard Worker ) 2258*da0073e9SAndroid Build Coastguard Worker 2259*da0073e9SAndroid Build Coastguard Worker for param_group, lr_target, momentum_target in zip( 2260*da0073e9SAndroid Build Coastguard Worker self.opt.param_groups, lr_targets, momentum_targets 2261*da0073e9SAndroid Build Coastguard Worker ): 2262*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2263*da0073e9SAndroid Build Coastguard Worker lr_target[batch_num], 2264*da0073e9SAndroid Build Coastguard Worker param_group["lr"], 2265*da0073e9SAndroid Build Coastguard Worker msg="LR is wrong in batch_num {}: expected {}, got {}".format( 2266*da0073e9SAndroid Build Coastguard Worker batch_num, lr_target[batch_num], param_group["lr"] 2267*da0073e9SAndroid Build Coastguard Worker ), 2268*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2269*da0073e9SAndroid Build Coastguard Worker rtol=0, 2270*da0073e9SAndroid Build Coastguard Worker ) 2271*da0073e9SAndroid Build Coastguard Worker 2272*da0073e9SAndroid Build Coastguard Worker if use_beta1 and "betas" in param_group.keys(): 2273*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2274*da0073e9SAndroid Build Coastguard Worker momentum_target[batch_num], 2275*da0073e9SAndroid Build Coastguard Worker param_group["betas"][0], 2276*da0073e9SAndroid Build Coastguard Worker msg="Beta1 is wrong in batch_num {}: expected {}, got {}".format( 2277*da0073e9SAndroid Build Coastguard Worker batch_num, 2278*da0073e9SAndroid Build Coastguard Worker momentum_target[batch_num], 2279*da0073e9SAndroid Build Coastguard Worker param_group["betas"][0], 2280*da0073e9SAndroid 
Build Coastguard Worker ), 2281*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2282*da0073e9SAndroid Build Coastguard Worker rtol=0, 2283*da0073e9SAndroid Build Coastguard Worker ) 2284*da0073e9SAndroid Build Coastguard Worker elif "momentum" in param_group.keys(): 2285*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2286*da0073e9SAndroid Build Coastguard Worker momentum_target[batch_num], 2287*da0073e9SAndroid Build Coastguard Worker param_group["momentum"], 2288*da0073e9SAndroid Build Coastguard Worker msg="Momentum is wrong in batch_num {}: expected {}, got {}".format( 2289*da0073e9SAndroid Build Coastguard Worker batch_num, 2290*da0073e9SAndroid Build Coastguard Worker momentum_target[batch_num], 2291*da0073e9SAndroid Build Coastguard Worker param_group["momentum"], 2292*da0073e9SAndroid Build Coastguard Worker ), 2293*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 2294*da0073e9SAndroid Build Coastguard Worker rtol=0, 2295*da0073e9SAndroid Build Coastguard Worker ) 2296*da0073e9SAndroid Build Coastguard Worker self.opt.step() 2297*da0073e9SAndroid Build Coastguard Worker scheduler.step() 2298*da0073e9SAndroid Build Coastguard Worker 2299*da0073e9SAndroid Build Coastguard Worker def test_cosine_then_cyclic(self): 2300*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/21965 2301*da0073e9SAndroid Build Coastguard Worker 2302*da0073e9SAndroid Build Coastguard Worker max_lr = 0.3 2303*da0073e9SAndroid Build Coastguard Worker base_lr = 0.1 2304*da0073e9SAndroid Build Coastguard Worker optim_lr = 0.5 2305*da0073e9SAndroid Build Coastguard Worker 2306*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Linear(2, 1) 2307*da0073e9SAndroid Build Coastguard Worker optimizer = SGD(model.parameters(), lr=optim_lr) 2308*da0073e9SAndroid Build Coastguard Worker lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR( 2309*da0073e9SAndroid Build Coastguard Worker optimizer, T_max=20, eta_min=0.1 
2310*da0073e9SAndroid Build Coastguard Worker ) 2311*da0073e9SAndroid Build Coastguard Worker lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR( 2312*da0073e9SAndroid Build Coastguard Worker optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3 2313*da0073e9SAndroid Build Coastguard Worker ) 2314*da0073e9SAndroid Build Coastguard Worker 2315*da0073e9SAndroid Build Coastguard Worker for i in range(40): 2316*da0073e9SAndroid Build Coastguard Worker optimizer.step() 2317*da0073e9SAndroid Build Coastguard Worker if i <= lr_scheduler_1.T_max: 2318*da0073e9SAndroid Build Coastguard Worker lr_scheduler_1.step() 2319*da0073e9SAndroid Build Coastguard Worker else: 2320*da0073e9SAndroid Build Coastguard Worker lr_scheduler_2.step() 2321*da0073e9SAndroid Build Coastguard Worker last_lr = optimizer.param_groups[0]["lr"] 2322*da0073e9SAndroid Build Coastguard Worker 2323*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual(last_lr, max_lr) 2324*da0073e9SAndroid Build Coastguard Worker 2325*da0073e9SAndroid Build Coastguard Worker @parametrize( 2326*da0073e9SAndroid Build Coastguard Worker "LRClass", 2327*da0073e9SAndroid Build Coastguard Worker [ 2328*da0073e9SAndroid Build Coastguard Worker partial(LambdaLR, lr_lambda=lambda e: e // 10), 2329*da0073e9SAndroid Build Coastguard Worker partial(MultiplicativeLR, lr_lambda=lambda: 0.95), 2330*da0073e9SAndroid Build Coastguard Worker partial(StepLR, step_size=30), 2331*da0073e9SAndroid Build Coastguard Worker partial(MultiStepLR, milestones=[30, 80]), 2332*da0073e9SAndroid Build Coastguard Worker ConstantLR, 2333*da0073e9SAndroid Build Coastguard Worker LinearLR, 2334*da0073e9SAndroid Build Coastguard Worker partial(ExponentialLR, gamma=0.9), 2335*da0073e9SAndroid Build Coastguard Worker lambda opt, **kwargs: SequentialLR( 2336*da0073e9SAndroid Build Coastguard Worker opt, 2337*da0073e9SAndroid Build Coastguard Worker schedulers=[ConstantLR(opt), ConstantLR(opt)], 2338*da0073e9SAndroid Build 
Coastguard Worker milestones=[2], 2339*da0073e9SAndroid Build Coastguard Worker **kwargs, 2340*da0073e9SAndroid Build Coastguard Worker ), 2341*da0073e9SAndroid Build Coastguard Worker PolynomialLR, 2342*da0073e9SAndroid Build Coastguard Worker partial(CosineAnnealingLR, T_max=10), 2343*da0073e9SAndroid Build Coastguard Worker ReduceLROnPlateau, 2344*da0073e9SAndroid Build Coastguard Worker partial(CyclicLR, base_lr=0.01, max_lr=0.1), 2345*da0073e9SAndroid Build Coastguard Worker partial(CosineAnnealingWarmRestarts, T_0=20), 2346*da0073e9SAndroid Build Coastguard Worker partial(OneCycleLR, max_lr=0.01, total_steps=10), 2347*da0073e9SAndroid Build Coastguard Worker ], 2348*da0073e9SAndroid Build Coastguard Worker ) 2349*da0073e9SAndroid Build Coastguard Worker def test_lr_scheduler_verbose_deprecation_warning(self, LRClass): 2350*da0073e9SAndroid Build Coastguard Worker """Check that a deprecating warning with verbose parameter.""" 2351*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex( 2352*da0073e9SAndroid Build Coastguard Worker UserWarning, "The verbose parameter is deprecated" 2353*da0073e9SAndroid Build Coastguard Worker ): 2354*da0073e9SAndroid Build Coastguard Worker LRClass(self.opt, verbose=True) 2355*da0073e9SAndroid Build Coastguard Worker 2356*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex( 2357*da0073e9SAndroid Build Coastguard Worker UserWarning, "The verbose parameter is deprecated" 2358*da0073e9SAndroid Build Coastguard Worker ): 2359*da0073e9SAndroid Build Coastguard Worker LRClass(self.opt, verbose=False) 2360*da0073e9SAndroid Build Coastguard Worker 2361*da0073e9SAndroid Build Coastguard Worker # No warning is raised when verbose is the default value. 
2362*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 2363*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("error", UserWarning) 2364*da0073e9SAndroid Build Coastguard Worker LRClass(self.opt) 2365*da0073e9SAndroid Build Coastguard Worker 2366*da0073e9SAndroid Build Coastguard Worker @parametrize( 2367*da0073e9SAndroid Build Coastguard Worker "LRClass", 2368*da0073e9SAndroid Build Coastguard Worker [ 2369*da0073e9SAndroid Build Coastguard Worker partial(LambdaLR, lr_lambda=lambda e: e // 10), 2370*da0073e9SAndroid Build Coastguard Worker partial(MultiplicativeLR, lr_lambda=lambda: 0.95), 2371*da0073e9SAndroid Build Coastguard Worker partial(StepLR, step_size=30), 2372*da0073e9SAndroid Build Coastguard Worker partial(MultiStepLR, milestones=[30, 80]), 2373*da0073e9SAndroid Build Coastguard Worker ConstantLR, 2374*da0073e9SAndroid Build Coastguard Worker LinearLR, 2375*da0073e9SAndroid Build Coastguard Worker partial(ExponentialLR, gamma=0.9), 2376*da0073e9SAndroid Build Coastguard Worker PolynomialLR, 2377*da0073e9SAndroid Build Coastguard Worker partial(CosineAnnealingLR, T_max=10), 2378*da0073e9SAndroid Build Coastguard Worker lambda opt, **kwargs: ChainedScheduler( 2379*da0073e9SAndroid Build Coastguard Worker schedulers=[ConstantLR(opt), ConstantLR(opt)], **kwargs 2380*da0073e9SAndroid Build Coastguard Worker ), 2381*da0073e9SAndroid Build Coastguard Worker lambda opt, **kwargs: SequentialLR( 2382*da0073e9SAndroid Build Coastguard Worker opt, 2383*da0073e9SAndroid Build Coastguard Worker schedulers=[ConstantLR(opt), ConstantLR(opt)], 2384*da0073e9SAndroid Build Coastguard Worker milestones=[2], 2385*da0073e9SAndroid Build Coastguard Worker **kwargs, 2386*da0073e9SAndroid Build Coastguard Worker ), 2387*da0073e9SAndroid Build Coastguard Worker ReduceLROnPlateau, 2388*da0073e9SAndroid Build Coastguard Worker partial(CyclicLR, base_lr=0.01, max_lr=0.1), 2389*da0073e9SAndroid Build Coastguard Worker partial(OneCycleLR, 
max_lr=0.01, total_steps=10, anneal_strategy="linear"), 2390*da0073e9SAndroid Build Coastguard Worker partial(CosineAnnealingWarmRestarts, T_0=20), 2391*da0073e9SAndroid Build Coastguard Worker ], 2392*da0073e9SAndroid Build Coastguard Worker ) 2393*da0073e9SAndroid Build Coastguard Worker @parametrize("weights_only", [True, False]) 2394*da0073e9SAndroid Build Coastguard Worker def test_lr_scheduler_state_dict_load(self, LRClass, weights_only): 2395*da0073e9SAndroid Build Coastguard Worker scheduler = LRClass(self.opt) 2396*da0073e9SAndroid Build Coastguard Worker state_dict = scheduler.state_dict() 2397*da0073e9SAndroid Build Coastguard Worker 2398*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryFile() as f: 2399*da0073e9SAndroid Build Coastguard Worker torch.save(state_dict, f) 2400*da0073e9SAndroid Build Coastguard Worker f.seek(0) 2401*da0073e9SAndroid Build Coastguard Worker state_dict_loaded = torch.load(f, weights_only=weights_only) 2402*da0073e9SAndroid Build Coastguard Worker self.assertEqual(state_dict, state_dict_loaded) 2403*da0073e9SAndroid Build Coastguard Worker # Make sure state_dict can be loaded 2404*da0073e9SAndroid Build Coastguard Worker scheduler2 = LRClass(self.opt) 2405*da0073e9SAndroid Build Coastguard Worker scheduler2.load_state_dict(state_dict_loaded) 2406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scheduler2.state_dict(), state_dict) 2407*da0073e9SAndroid Build Coastguard Worker 2408*da0073e9SAndroid Build Coastguard Worker @parametrize( 2409*da0073e9SAndroid Build Coastguard Worker "LRClass", 2410*da0073e9SAndroid Build Coastguard Worker [ 2411*da0073e9SAndroid Build Coastguard Worker partial(LambdaLR, lr_lambda=lambda e: e // 10), 2412*da0073e9SAndroid Build Coastguard Worker partial(MultiplicativeLR, lr_lambda=lambda e: 0.95), 2413*da0073e9SAndroid Build Coastguard Worker partial(StepLR, step_size=30), 2414*da0073e9SAndroid Build Coastguard Worker partial(MultiStepLR, milestones=[30, 80]), 
2415*da0073e9SAndroid Build Coastguard Worker ConstantLR, 2416*da0073e9SAndroid Build Coastguard Worker LinearLR, 2417*da0073e9SAndroid Build Coastguard Worker partial(ExponentialLR, gamma=0.9), 2418*da0073e9SAndroid Build Coastguard Worker PolynomialLR, 2419*da0073e9SAndroid Build Coastguard Worker partial(CosineAnnealingLR, T_max=10), 2420*da0073e9SAndroid Build Coastguard Worker partial(CosineAnnealingWarmRestarts, T_0=20), 2421*da0073e9SAndroid Build Coastguard Worker ], 2422*da0073e9SAndroid Build Coastguard Worker ) 2423*da0073e9SAndroid Build Coastguard Worker def test_constant_initial_lr(self, LRClass): 2424*da0073e9SAndroid Build Coastguard Worker # Test that the initial learning rate is constant 2425*da0073e9SAndroid Build Coastguard Worker lr = torch.as_tensor(0.1) 2426*da0073e9SAndroid Build Coastguard Worker opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) 2427*da0073e9SAndroid Build Coastguard Worker sch = LRClass(opt) 2428*da0073e9SAndroid Build Coastguard Worker 2429*da0073e9SAndroid Build Coastguard Worker ori_param_groups = copy.deepcopy(opt.param_groups) 2430*da0073e9SAndroid Build Coastguard Worker 2431*da0073e9SAndroid Build Coastguard Worker for i in range(2): 2432*da0073e9SAndroid Build Coastguard Worker opt.step() 2433*da0073e9SAndroid Build Coastguard Worker sch.step(i) 2434*da0073e9SAndroid Build Coastguard Worker lr.multiply_(0.1) 2435*da0073e9SAndroid Build Coastguard Worker for group, ori_group in zip(opt.param_groups, ori_param_groups): 2436*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) 2437*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sch.base_lrs, [0.1]) 2438*da0073e9SAndroid Build Coastguard Worker 2439*da0073e9SAndroid Build Coastguard Worker def test_constant_initial_params_cyclelr(self): 2440*da0073e9SAndroid Build Coastguard Worker # Test that the initial learning rate is constant 2441*da0073e9SAndroid Build Coastguard Worker lr = torch.as_tensor(0.1) 
2442*da0073e9SAndroid Build Coastguard Worker max_lr = torch.as_tensor(0.2) 2443*da0073e9SAndroid Build Coastguard Worker base_momentum = torch.as_tensor(0.8) 2444*da0073e9SAndroid Build Coastguard Worker max_momentum = torch.as_tensor(0.9) 2445*da0073e9SAndroid Build Coastguard Worker opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) 2446*da0073e9SAndroid Build Coastguard Worker sch = CyclicLR( 2447*da0073e9SAndroid Build Coastguard Worker opt, 2448*da0073e9SAndroid Build Coastguard Worker base_lr=lr, 2449*da0073e9SAndroid Build Coastguard Worker max_lr=max_lr, 2450*da0073e9SAndroid Build Coastguard Worker base_momentum=base_momentum, 2451*da0073e9SAndroid Build Coastguard Worker max_momentum=max_momentum, 2452*da0073e9SAndroid Build Coastguard Worker ) 2453*da0073e9SAndroid Build Coastguard Worker ori_param_groups = copy.deepcopy(opt.param_groups) 2454*da0073e9SAndroid Build Coastguard Worker 2455*da0073e9SAndroid Build Coastguard Worker for i in range(2): 2456*da0073e9SAndroid Build Coastguard Worker lr.multiply_(0.5) 2457*da0073e9SAndroid Build Coastguard Worker max_lr.multiply_(0.5) 2458*da0073e9SAndroid Build Coastguard Worker base_momentum.multiply_(0.5) 2459*da0073e9SAndroid Build Coastguard Worker max_momentum.multiply_(0.5) 2460*da0073e9SAndroid Build Coastguard Worker opt.step() 2461*da0073e9SAndroid Build Coastguard Worker sch.step(i) 2462*da0073e9SAndroid Build Coastguard Worker for group, ori_group in zip(opt.param_groups, ori_param_groups): 2463*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) 2464*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["max_momentum"], ori_group["max_momentum"]) 2465*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["base_momentum"], ori_group["base_momentum"]) 2466*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sch.base_lrs, [0.1]) 2467*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sch.max_lrs, [0.2]) 
2468*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["max_momentum"], 0.9) 2469*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["base_momentum"], 0.8) 2470*da0073e9SAndroid Build Coastguard Worker 2471*da0073e9SAndroid Build Coastguard Worker def test_constant_initial_params_onecyclelr(self): 2472*da0073e9SAndroid Build Coastguard Worker # Test that the initial learning rate is constant 2473*da0073e9SAndroid Build Coastguard Worker lr = torch.as_tensor(0.1) 2474*da0073e9SAndroid Build Coastguard Worker base_momentum = torch.as_tensor(0.85) 2475*da0073e9SAndroid Build Coastguard Worker max_momentum = torch.as_tensor(0.95) 2476*da0073e9SAndroid Build Coastguard Worker opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) 2477*da0073e9SAndroid Build Coastguard Worker sch = OneCycleLR( 2478*da0073e9SAndroid Build Coastguard Worker opt, 2479*da0073e9SAndroid Build Coastguard Worker max_lr=lr, 2480*da0073e9SAndroid Build Coastguard Worker total_steps=10, 2481*da0073e9SAndroid Build Coastguard Worker base_momentum=base_momentum, 2482*da0073e9SAndroid Build Coastguard Worker max_momentum=max_momentum, 2483*da0073e9SAndroid Build Coastguard Worker ) 2484*da0073e9SAndroid Build Coastguard Worker ori_param_groups = copy.deepcopy(opt.param_groups) 2485*da0073e9SAndroid Build Coastguard Worker 2486*da0073e9SAndroid Build Coastguard Worker for i in range(2): 2487*da0073e9SAndroid Build Coastguard Worker lr.multiply_(0.5) 2488*da0073e9SAndroid Build Coastguard Worker base_momentum.multiply_(0.5) 2489*da0073e9SAndroid Build Coastguard Worker max_momentum.multiply_(0.5) 2490*da0073e9SAndroid Build Coastguard Worker opt.step() 2491*da0073e9SAndroid Build Coastguard Worker sch.step(i) 2492*da0073e9SAndroid Build Coastguard Worker 2493*da0073e9SAndroid Build Coastguard Worker for group, ori_group in zip(opt.param_groups, ori_param_groups): 2494*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) 
2495*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["max_lr"], ori_group["max_lr"]) 2496*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["min_lr"], ori_group["min_lr"]) 2497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["max_momentum"], ori_group["max_momentum"]) 2498*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["base_momentum"], ori_group["base_momentum"]) 2499*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["max_momentum"], 0.95) 2500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["base_momentum"], 0.85) 2501*da0073e9SAndroid Build Coastguard Worker 2502*da0073e9SAndroid Build Coastguard Worker def test_constant_initial_params_swalr(self): 2503*da0073e9SAndroid Build Coastguard Worker # Test that the initial learning rate is constant 2504*da0073e9SAndroid Build Coastguard Worker lr = torch.as_tensor(0.1) 2505*da0073e9SAndroid Build Coastguard Worker swa_lr = torch.as_tensor(0.05) 2506*da0073e9SAndroid Build Coastguard Worker opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) 2507*da0073e9SAndroid Build Coastguard Worker sch = SWALR(opt, swa_lr=swa_lr) 2508*da0073e9SAndroid Build Coastguard Worker ori_param_groups = copy.deepcopy(opt.param_groups) 2509*da0073e9SAndroid Build Coastguard Worker 2510*da0073e9SAndroid Build Coastguard Worker for i in range(2): 2511*da0073e9SAndroid Build Coastguard Worker lr.multiply_(0.5) 2512*da0073e9SAndroid Build Coastguard Worker swa_lr.multiply_(0.5) 2513*da0073e9SAndroid Build Coastguard Worker opt.step() 2514*da0073e9SAndroid Build Coastguard Worker sch.step() 2515*da0073e9SAndroid Build Coastguard Worker for group, ori_group in zip(opt.param_groups, ori_param_groups): 2516*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) 2517*da0073e9SAndroid Build Coastguard Worker self.assertEqual(group["swa_lr"], ori_group["swa_lr"]) 2518*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(group["swa_lr"], 0.05) 2519*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sch.base_lrs, [0.1]) 2520*da0073e9SAndroid Build Coastguard Worker 2521*da0073e9SAndroid Build Coastguard Worker 2522*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestLRScheduler) 2523*da0073e9SAndroid Build Coastguard Worker 2524*da0073e9SAndroid Build Coastguard Worker 2525*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 2526*da0073e9SAndroid Build Coastguard Worker print("These tests should be run through test/test_optim.py instead") 2527