xref: /aosp_15_r20/external/pytorch/test/optim/test_lrscheduler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: optimizer", "module: LrScheduler" ]
2*da0073e9SAndroid Build Coastguard Workerimport copy
3*da0073e9SAndroid Build Coastguard Workerimport math
4*da0073e9SAndroid Build Coastguard Workerimport pickle
5*da0073e9SAndroid Build Coastguard Workerimport tempfile
6*da0073e9SAndroid Build Coastguard Workerimport types
7*da0073e9SAndroid Build Coastguard Workerimport warnings
8*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport torch
11*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
12*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import Parameter
13*da0073e9SAndroid Build Coastguard Workerfrom torch.optim import Adam, Rprop, SGD
14*da0073e9SAndroid Build Coastguard Workerfrom torch.optim.lr_scheduler import (
15*da0073e9SAndroid Build Coastguard Worker    ChainedScheduler,
16*da0073e9SAndroid Build Coastguard Worker    ConstantLR,
17*da0073e9SAndroid Build Coastguard Worker    CosineAnnealingLR,
18*da0073e9SAndroid Build Coastguard Worker    CosineAnnealingWarmRestarts,
19*da0073e9SAndroid Build Coastguard Worker    CyclicLR,
20*da0073e9SAndroid Build Coastguard Worker    EPOCH_DEPRECATION_WARNING,
21*da0073e9SAndroid Build Coastguard Worker    ExponentialLR,
22*da0073e9SAndroid Build Coastguard Worker    LambdaLR,
23*da0073e9SAndroid Build Coastguard Worker    LinearLR,
24*da0073e9SAndroid Build Coastguard Worker    LRScheduler,
25*da0073e9SAndroid Build Coastguard Worker    MultiplicativeLR,
26*da0073e9SAndroid Build Coastguard Worker    MultiStepLR,
27*da0073e9SAndroid Build Coastguard Worker    OneCycleLR,
28*da0073e9SAndroid Build Coastguard Worker    PolynomialLR,
29*da0073e9SAndroid Build Coastguard Worker    ReduceLROnPlateau,
30*da0073e9SAndroid Build Coastguard Worker    SequentialLR,
31*da0073e9SAndroid Build Coastguard Worker    StepLR,
32*da0073e9SAndroid Build Coastguard Worker)
33*da0073e9SAndroid Build Coastguard Workerfrom torch.optim.swa_utils import SWALR
34*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
35*da0073e9SAndroid Build Coastguard Worker    instantiate_parametrized_tests,
36*da0073e9SAndroid Build Coastguard Worker    load_tests,
37*da0073e9SAndroid Build Coastguard Worker    parametrize,
38*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
39*da0073e9SAndroid Build Coastguard Worker    TestCase,
40*da0073e9SAndroid Build Coastguard Worker)
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker# load_tests from common_utils is used to automatically filter tests for
44*da0073e9SAndroid Build Coastguard Worker# sharding on sandcastle. This line silences flake warnings
45*da0073e9SAndroid Build Coastguard Workerload_tests = load_tests
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Workerclass TestLRScheduler(TestCase):
49*da0073e9SAndroid Build Coastguard Worker    class SchedulerTestNet(torch.nn.Module):
50*da0073e9SAndroid Build Coastguard Worker        def __init__(self) -> None:
51*da0073e9SAndroid Build Coastguard Worker            super().__init__()
52*da0073e9SAndroid Build Coastguard Worker            self.conv1 = torch.nn.Conv2d(1, 1, 1)
53*da0073e9SAndroid Build Coastguard Worker            self.conv2 = torch.nn.Conv2d(1, 1, 1)
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker        def forward(self, x):
56*da0073e9SAndroid Build Coastguard Worker            return self.conv2(F.relu(self.conv1(x)))
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker    class LambdaLRTestObject:
59*da0073e9SAndroid Build Coastguard Worker        def __init__(self, value):
60*da0073e9SAndroid Build Coastguard Worker            self.value = value
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker        def __call__(self, epoch):
63*da0073e9SAndroid Build Coastguard Worker            return self.value * epoch
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        def __eq__(self, other):
66*da0073e9SAndroid Build Coastguard Worker            if isinstance(other, self.__class__):
67*da0073e9SAndroid Build Coastguard Worker                return self.__dict__ == other.__dict__
68*da0073e9SAndroid Build Coastguard Worker            else:
69*da0073e9SAndroid Build Coastguard Worker                return False
70*da0073e9SAndroid Build Coastguard Worker
    # NOTE(review): presumably read by the torch ``TestCase`` comparison helpers
    # (e.g. ``assertEqual``) to require exactly matching dtypes — confirm against
    # torch.testing._internal.common_utils.
    exact_dtype = True
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
74*da0073e9SAndroid Build Coastguard Worker        super().setUp()
75*da0073e9SAndroid Build Coastguard Worker        self.net = self.SchedulerTestNet()
76*da0073e9SAndroid Build Coastguard Worker        self.opt = SGD(
77*da0073e9SAndroid Build Coastguard Worker            [
78*da0073e9SAndroid Build Coastguard Worker                {"params": self.net.conv1.parameters()},
79*da0073e9SAndroid Build Coastguard Worker                {"params": self.net.conv2.parameters(), "lr": 0.5},
80*da0073e9SAndroid Build Coastguard Worker            ],
81*da0073e9SAndroid Build Coastguard Worker            lr=0.05,
82*da0073e9SAndroid Build Coastguard Worker        )
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker    def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1):
85*da0073e9SAndroid Build Coastguard Worker        """This function swallows the epoch deprecation warning which is produced when we
86*da0073e9SAndroid Build Coastguard Worker        call `scheduler.step(epoch)` with some not `None` value of `epoch`.
87*da0073e9SAndroid Build Coastguard Worker        this is deprecated, and this function will need to be removed/updated when
88*da0073e9SAndroid Build Coastguard Worker        the schedulers no longer accept the parameter at all.
89*da0073e9SAndroid Build Coastguard Worker        """
90*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(w), num_warnings)
91*da0073e9SAndroid Build Coastguard Worker        for warning in w:
92*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(warning.message.args), 1)
93*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING)
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker    def test_error_when_getlr_has_epoch(self):
96*da0073e9SAndroid Build Coastguard Worker        class MultiStepLR(torch.optim.lr_scheduler.LRScheduler):
97*da0073e9SAndroid Build Coastguard Worker            def __init__(self, optimizer, gamma, milestones, last_epoch=-1):
98*da0073e9SAndroid Build Coastguard Worker                self.init_lr = [group["lr"] for group in optimizer.param_groups]
99*da0073e9SAndroid Build Coastguard Worker                self.gamma = gamma
100*da0073e9SAndroid Build Coastguard Worker                self.milestones = milestones
101*da0073e9SAndroid Build Coastguard Worker                super().__init__(optimizer, last_epoch)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker            def get_lr(self, step):
104*da0073e9SAndroid Build Coastguard Worker                global_step = self.last_epoch
105*da0073e9SAndroid Build Coastguard Worker                gamma_power = (
106*da0073e9SAndroid Build Coastguard Worker                    [0]
107*da0073e9SAndroid Build Coastguard Worker                    + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m]
108*da0073e9SAndroid Build Coastguard Worker                )[-1]
109*da0073e9SAndroid Build Coastguard Worker                return [
110*da0073e9SAndroid Build Coastguard Worker                    init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr
111*da0073e9SAndroid Build Coastguard Worker                ]
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker        optimizer = SGD([torch.rand(1)], lr=1)
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
116*da0073e9SAndroid Build Coastguard Worker            scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20])
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo(
119*da0073e9SAndroid Build Coastguard Worker        "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames"
120*da0073e9SAndroid Build Coastguard Worker    )
121*da0073e9SAndroid Build Coastguard Worker    def test_no_cyclic_references(self):
122*da0073e9SAndroid Build Coastguard Worker        import gc
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker        param = Parameter(torch.empty(10))
125*da0073e9SAndroid Build Coastguard Worker        optim = SGD([param], lr=0.5)
126*da0073e9SAndroid Build Coastguard Worker        scheduler = LambdaLR(optim, lambda epoch: 1.0)
127*da0073e9SAndroid Build Coastguard Worker        del scheduler
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
130*da0073e9SAndroid Build Coastguard Worker            len(gc.get_referrers(optim)) == 0,
131*da0073e9SAndroid Build Coastguard Worker            "Optimizer should contain no cyclic references",
132*da0073e9SAndroid Build Coastguard Worker        )
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker        gc.collect()
135*da0073e9SAndroid Build Coastguard Worker        del optim
136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
137*da0073e9SAndroid Build Coastguard Worker            gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__"
138*da0073e9SAndroid Build Coastguard Worker        )
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo(
141*da0073e9SAndroid Build Coastguard Worker        "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames"
142*da0073e9SAndroid Build Coastguard Worker    )
143*da0073e9SAndroid Build Coastguard Worker    def test_no_cyclic_references_in_step(self):
144*da0073e9SAndroid Build Coastguard Worker        import gc
145*da0073e9SAndroid Build Coastguard Worker        import weakref
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker        def run():
148*da0073e9SAndroid Build Coastguard Worker            param = torch.empty(10, requires_grad=True)
149*da0073e9SAndroid Build Coastguard Worker            optim = SGD(params=[param], lr=0.5)
150*da0073e9SAndroid Build Coastguard Worker            scheduler = LambdaLR(optim, lambda epoch: 1.0)
151*da0073e9SAndroid Build Coastguard Worker            param.sum().backward()
152*da0073e9SAndroid Build Coastguard Worker            optim.step()
153*da0073e9SAndroid Build Coastguard Worker            scheduler.step()
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker            return weakref.ref(scheduler)
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker        # To ensure that there are no reference cycles in scheduler,
158*da0073e9SAndroid Build Coastguard Worker        # we need to turn off the garbage collector. Since gc will
159*da0073e9SAndroid Build Coastguard Worker        # automatically collect unreachable objects.
160*da0073e9SAndroid Build Coastguard Worker        gc.disable()
161*da0073e9SAndroid Build Coastguard Worker        ref = run()
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker        assert ref() is None
164*da0073e9SAndroid Build Coastguard Worker        gc.enable()  # restore
165*da0073e9SAndroid Build Coastguard Worker
166*da0073e9SAndroid Build Coastguard Worker    def test_old_pattern_warning(self):
167*da0073e9SAndroid Build Coastguard Worker        epochs = 35
168*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
169*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
170*da0073e9SAndroid Build Coastguard Worker            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
171*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker        def old_pattern():
174*da0073e9SAndroid Build Coastguard Worker            for _ in range(epochs):
175*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
176*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern)
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker    def test_old_pattern_warning_with_arg(self):
181*da0073e9SAndroid Build Coastguard Worker        epochs = 35
182*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
183*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
184*da0073e9SAndroid Build Coastguard Worker            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
185*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker        def old_pattern2():
188*da0073e9SAndroid Build Coastguard Worker            for _ in range(epochs):
189*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
190*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker    def test_old_pattern_warning_resuming(self):
195*da0073e9SAndroid Build Coastguard Worker        epochs = 35
196*da0073e9SAndroid Build Coastguard Worker        for i, group in enumerate(self.opt.param_groups):
197*da0073e9SAndroid Build Coastguard Worker            group["initial_lr"] = 0.01
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
200*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
201*da0073e9SAndroid Build Coastguard Worker            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
202*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker        def old_pattern():
205*da0073e9SAndroid Build Coastguard Worker            for _ in range(epochs):
206*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
207*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern)
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker    def test_old_pattern_warning_resuming_with_arg(self):
212*da0073e9SAndroid Build Coastguard Worker        epochs = 35
213*da0073e9SAndroid Build Coastguard Worker        for i, group in enumerate(self.opt.param_groups):
214*da0073e9SAndroid Build Coastguard Worker            group["initial_lr"] = 0.01
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
217*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
218*da0073e9SAndroid Build Coastguard Worker            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
219*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker        def old_pattern2():
222*da0073e9SAndroid Build Coastguard Worker            for _ in range(epochs):
223*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
224*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker    def test_old_pattern_warning_with_overridden_optim_step(self):
229*da0073e9SAndroid Build Coastguard Worker        epochs = 35
230*da0073e9SAndroid Build Coastguard Worker        for i, group in enumerate(self.opt.param_groups):
231*da0073e9SAndroid Build Coastguard Worker            group["initial_lr"] = 0.01
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
234*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
235*da0073e9SAndroid Build Coastguard Worker            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
236*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker        # emulate use-case with optimizer.step overridden
239*da0073e9SAndroid Build Coastguard Worker        import types
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker        old_step = self.opt.step
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker        def new_step(o, *args, **kwargs):
244*da0073e9SAndroid Build Coastguard Worker            retval = old_step(*args, **kwargs)
245*da0073e9SAndroid Build Coastguard Worker            return retval
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker        self.opt.step = types.MethodType(new_step, self.opt)
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker        def old_pattern2():
250*da0073e9SAndroid Build Coastguard Worker            for _ in range(epochs):
251*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
252*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker    def test_new_pattern_no_warning(self):
257*da0073e9SAndroid Build Coastguard Worker        epochs = 35
258*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
259*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
260*da0073e9SAndroid Build Coastguard Worker            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
261*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
264*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
265*da0073e9SAndroid Build Coastguard Worker            for _ in range(epochs):
266*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
267*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
268*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker    def test_new_pattern_no_warning_with_arg(self):
271*da0073e9SAndroid Build Coastguard Worker        epochs = 35
272*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
273*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
274*da0073e9SAndroid Build Coastguard Worker            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
275*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
278*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
279*da0073e9SAndroid Build Coastguard Worker            for _ in range(epochs):
280*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
281*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
282*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Worker    def test_new_pattern_no_warning_with_overridden_optim_step(self):
285*da0073e9SAndroid Build Coastguard Worker        epochs = 35
286*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
287*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")  # allow any warning to be raised
288*da0073e9SAndroid Build Coastguard Worker            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
289*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(ws) == 0, "No warning should be raised")
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker        # emulate use-case with optimizer.step overridden
292*da0073e9SAndroid Build Coastguard Worker        import types
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker        old_step = self.opt.step
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker        def new_step(o, *args, **kwargs):
297*da0073e9SAndroid Build Coastguard Worker            retval = old_step(*args, **kwargs)
298*da0073e9SAndroid Build Coastguard Worker            return retval
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker        self.opt.step = types.MethodType(new_step, self.opt)
301*da0073e9SAndroid Build Coastguard Worker
302*da0073e9SAndroid Build Coastguard Worker        def new_pattern():
303*da0073e9SAndroid Build Coastguard Worker            for e in range(epochs):
304*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
305*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker        self.assertWarnsRegex(
308*da0073e9SAndroid Build Coastguard Worker            UserWarning, r"`optimizer.step\(\)` has been overridden", new_pattern
309*da0073e9SAndroid Build Coastguard Worker        )
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker    def _test_lr_is_constant_for_constant_epoch(self, scheduler):
312*da0073e9SAndroid Build Coastguard Worker        l = []
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
315*da0073e9SAndroid Build Coastguard Worker            scheduler.optimizer.step()
316*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
317*da0073e9SAndroid Build Coastguard Worker                scheduler.step(2)
318*da0073e9SAndroid Build Coastguard Worker                self._check_warning_is_epoch_deprecation_warning(w)
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker            l.append(self.opt.param_groups[0]["lr"])
321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(min(l), max(l))
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker    def test_step_lr_is_constant_for_constant_epoch(self):
324*da0073e9SAndroid Build Coastguard Worker        scheduler = StepLR(self.opt, 2)
325*da0073e9SAndroid Build Coastguard Worker        self._test_lr_is_constant_for_constant_epoch(scheduler)
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker    def test_exponential_lr_is_constant_for_constant_epoch(self):
328*da0073e9SAndroid Build Coastguard Worker        scheduler = ExponentialLR(self.opt, gamma=0.9)
329*da0073e9SAndroid Build Coastguard Worker        self._test_lr_is_constant_for_constant_epoch(scheduler)
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker    def test_constantlr_is_constant_for_constant_epoch(self):
332*da0073e9SAndroid Build Coastguard Worker        scheduler = ConstantLR(self.opt)
333*da0073e9SAndroid Build Coastguard Worker        self._test_lr_is_constant_for_constant_epoch(scheduler)
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker    def test_linear_linearlr_is_constant_for_constant_epoch(self):
336*da0073e9SAndroid Build Coastguard Worker        scheduler = LinearLR(self.opt)
337*da0073e9SAndroid Build Coastguard Worker        self._test_lr_is_constant_for_constant_epoch(scheduler)
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker    def test_polynomial_lr_is_constant_for_constant_epoch(self):
340*da0073e9SAndroid Build Coastguard Worker        scheduler = PolynomialLR(self.opt, power=0.9)
341*da0073e9SAndroid Build Coastguard Worker        self._test_lr_is_constant_for_constant_epoch(scheduler)
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker    def test_step_lr(self):
344*da0073e9SAndroid Build Coastguard Worker        # lr = 0.05     if epoch < 3
345*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005    if 30 <= epoch < 6
346*da0073e9SAndroid Build Coastguard Worker        # lr = 0.0005   if epoch >= 9
347*da0073e9SAndroid Build Coastguard Worker        epochs = 10
348*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
349*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
350*da0073e9SAndroid Build Coastguard Worker        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
351*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker    def test_get_last_lr_step_lr(self):
354*da0073e9SAndroid Build Coastguard Worker        from torch.nn import Parameter
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker        epochs = 10
357*da0073e9SAndroid Build Coastguard Worker        optimizer = SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1)
358*da0073e9SAndroid Build Coastguard Worker        targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]]
359*da0073e9SAndroid Build Coastguard Worker        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)
360*da0073e9SAndroid Build Coastguard Worker        self._test_get_last_lr(scheduler, targets, epochs)
361*da0073e9SAndroid Build Coastguard Worker
362*da0073e9SAndroid Build Coastguard Worker    def test_get_last_lr_multi_step_lr(self):
363*da0073e9SAndroid Build Coastguard Worker        # lr = 0.05     if epoch < 2
364*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005    if 2 <= epoch < 5
365*da0073e9SAndroid Build Coastguard Worker        # lr = 0.0005   if 5 <= epoch < 9
366*da0073e9SAndroid Build Coastguard Worker        # lr = 0.00005   if 9 <= epoch
367*da0073e9SAndroid Build Coastguard Worker        epochs = 10
368*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1
369*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
370*da0073e9SAndroid Build Coastguard Worker        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
371*da0073e9SAndroid Build Coastguard Worker        self._test_get_last_lr(scheduler, targets, epochs)
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker    def test_multi_step_lr(self):
374*da0073e9SAndroid Build Coastguard Worker        # lr = 0.05     if epoch < 2
375*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005    if 2 <= epoch < 5
376*da0073e9SAndroid Build Coastguard Worker        # lr = 0.0005   if epoch < 9
377*da0073e9SAndroid Build Coastguard Worker        # lr = 0.00005   if epoch >= 9
378*da0073e9SAndroid Build Coastguard Worker        epochs = 10
379*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
380*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
381*da0073e9SAndroid Build Coastguard Worker        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
382*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker    def test_multi_step_lr_with_epoch(self):
385*da0073e9SAndroid Build Coastguard Worker        # lr = 0.05     if epoch < 2
386*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005    if 2 <= epoch < 5
387*da0073e9SAndroid Build Coastguard Worker        # lr = 0.0005   if epoch < 9
388*da0073e9SAndroid Build Coastguard Worker        # lr = 0.00005   if epoch >= 9
389*da0073e9SAndroid Build Coastguard Worker        epochs = 10
390*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
391*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
392*da0073e9SAndroid Build Coastguard Worker        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
393*da0073e9SAndroid Build Coastguard Worker        self._test_with_epoch(scheduler, targets, epochs)
394*da0073e9SAndroid Build Coastguard Worker
395*da0073e9SAndroid Build Coastguard Worker    def test_get_last_lr_constantlr(self):
396*da0073e9SAndroid Build Coastguard Worker        # lr = 0.025     if epoch < 5
397*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005    if 5 <= epoch
398*da0073e9SAndroid Build Coastguard Worker        epochs = 10
399*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.025] * 5 + [0.05] * 5
400*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
401*da0073e9SAndroid Build Coastguard Worker        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
402*da0073e9SAndroid Build Coastguard Worker        self._test_get_last_lr(scheduler, targets, epochs)
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker    def test_get_last_lr_linearlr(self):
405*da0073e9SAndroid Build Coastguard Worker        # lr = 0.025     if epoch == 0
406*da0073e9SAndroid Build Coastguard Worker        # lr = 0.03125   if epoch == 1
407*da0073e9SAndroid Build Coastguard Worker        # lr = 0.0375    if epoch == 2
408*da0073e9SAndroid Build Coastguard Worker        # lr = 0.04375   if epoch == 3
409*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005     if 4 <= epoch
410*da0073e9SAndroid Build Coastguard Worker        epochs = 10
411*da0073e9SAndroid Build Coastguard Worker        start_factor = 1.0 / 4
412*da0073e9SAndroid Build Coastguard Worker        end_factor = 3.0 / 5
413*da0073e9SAndroid Build Coastguard Worker        iters = 4
414*da0073e9SAndroid Build Coastguard Worker        interpolation = [
415*da0073e9SAndroid Build Coastguard Worker            start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
416*da0073e9SAndroid Build Coastguard Worker        ]
417*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (
418*da0073e9SAndroid Build Coastguard Worker            epochs - iters
419*da0073e9SAndroid Build Coastguard Worker        )
420*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
421*da0073e9SAndroid Build Coastguard Worker        scheduler = LinearLR(
422*da0073e9SAndroid Build Coastguard Worker            self.opt,
423*da0073e9SAndroid Build Coastguard Worker            start_factor=start_factor,
424*da0073e9SAndroid Build Coastguard Worker            end_factor=end_factor,
425*da0073e9SAndroid Build Coastguard Worker            total_iters=iters,
426*da0073e9SAndroid Build Coastguard Worker        )
427*da0073e9SAndroid Build Coastguard Worker        self._test_get_last_lr(scheduler, targets, epochs)
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker    def test_constantlr(self):
430*da0073e9SAndroid Build Coastguard Worker        # lr = 0.025     if epoch < 5
431*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005    if 5 <= epoch
432*da0073e9SAndroid Build Coastguard Worker        epochs = 10
433*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.025] * 5 + [0.05] * 5
434*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
435*da0073e9SAndroid Build Coastguard Worker        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
436*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
437*da0073e9SAndroid Build Coastguard Worker
438*da0073e9SAndroid Build Coastguard Worker    def test_linearlr(self):
439*da0073e9SAndroid Build Coastguard Worker        # lr = 0.025     if epoch == 0
440*da0073e9SAndroid Build Coastguard Worker        # lr = 0.03125   if epoch == 1
441*da0073e9SAndroid Build Coastguard Worker        # lr = 0.0375    if epoch == 2
442*da0073e9SAndroid Build Coastguard Worker        # lr = 0.04375   if epoch == 3
443*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005     if 4 <= epoch
444*da0073e9SAndroid Build Coastguard Worker        epochs = 10
445*da0073e9SAndroid Build Coastguard Worker        start_factor = 1.0 / 2
446*da0073e9SAndroid Build Coastguard Worker        iters = 4
447*da0073e9SAndroid Build Coastguard Worker        interpolation = [
448*da0073e9SAndroid Build Coastguard Worker            start_factor + i * (1 - start_factor) / iters for i in range(iters)
449*da0073e9SAndroid Build Coastguard Worker        ]
450*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
451*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
452*da0073e9SAndroid Build Coastguard Worker        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
453*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker    def test_linearlr_start_factor_limits1(self):
456*da0073e9SAndroid Build Coastguard Worker        start_factor = 0.0
457*da0073e9SAndroid Build Coastguard Worker        iters = 4
458*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
459*da0073e9SAndroid Build Coastguard Worker            LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker    def test_linearlr_start_factor_limits2(self):
462*da0073e9SAndroid Build Coastguard Worker        start_factor = 1.1
463*da0073e9SAndroid Build Coastguard Worker        iters = 4
464*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
465*da0073e9SAndroid Build Coastguard Worker            LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker    def test_constantlr_with_epoch(self):
468*da0073e9SAndroid Build Coastguard Worker        # lr = 0.025     if epoch < 5
469*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005    if 5 <= epoch
470*da0073e9SAndroid Build Coastguard Worker        epochs = 10
471*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.025] * 5 + [0.05] * 5
472*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
473*da0073e9SAndroid Build Coastguard Worker        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
474*da0073e9SAndroid Build Coastguard Worker        self._test_with_epoch(scheduler, targets, epochs)
475*da0073e9SAndroid Build Coastguard Worker
476*da0073e9SAndroid Build Coastguard Worker    def test_linearlr_with_epoch(self):
477*da0073e9SAndroid Build Coastguard Worker        # lr = 0.025     if epoch == 0
478*da0073e9SAndroid Build Coastguard Worker        # lr = 0.03125   if epoch == 1
479*da0073e9SAndroid Build Coastguard Worker        # lr = 0.0375    if epoch == 2
480*da0073e9SAndroid Build Coastguard Worker        # lr = 0.04375   if epoch == 3
481*da0073e9SAndroid Build Coastguard Worker        # lr = 0.005     if 4 <= epoch
482*da0073e9SAndroid Build Coastguard Worker        epochs = 10
483*da0073e9SAndroid Build Coastguard Worker        start_factor = 1.0 / 2
484*da0073e9SAndroid Build Coastguard Worker        end_factor = 1.0
485*da0073e9SAndroid Build Coastguard Worker        iters = 4
486*da0073e9SAndroid Build Coastguard Worker        interpolation = [
487*da0073e9SAndroid Build Coastguard Worker            start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
488*da0073e9SAndroid Build Coastguard Worker        ]
489*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
490*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
491*da0073e9SAndroid Build Coastguard Worker        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
492*da0073e9SAndroid Build Coastguard Worker        self._test_with_epoch(scheduler, targets, epochs)
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker    def test_exp_lr(self):
495*da0073e9SAndroid Build Coastguard Worker        epochs = 10
496*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05 * (0.9**x) for x in range(epochs)]
497*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
498*da0073e9SAndroid Build Coastguard Worker        scheduler = ExponentialLR(self.opt, gamma=0.9)
499*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker    def test_poly_lr(self):
502*da0073e9SAndroid Build Coastguard Worker        epochs = 10
503*da0073e9SAndroid Build Coastguard Worker        power = 0.9
504*da0073e9SAndroid Build Coastguard Worker        total_iters = 5
505*da0073e9SAndroid Build Coastguard Worker        single_targets = [
506*da0073e9SAndroid Build Coastguard Worker            (1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters)
507*da0073e9SAndroid Build Coastguard Worker        ] + [0.0] * (epochs - total_iters)
508*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
509*da0073e9SAndroid Build Coastguard Worker        scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters)
510*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker    def test_cos_anneal_lr(self):
513*da0073e9SAndroid Build Coastguard Worker        epochs = 10
514*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
515*da0073e9SAndroid Build Coastguard Worker        single_targets = [
516*da0073e9SAndroid Build Coastguard Worker            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
517*da0073e9SAndroid Build Coastguard Worker            for x in range(epochs)
518*da0073e9SAndroid Build Coastguard Worker        ]
519*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
520*da0073e9SAndroid Build Coastguard Worker        scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
521*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
522*da0073e9SAndroid Build Coastguard Worker
523*da0073e9SAndroid Build Coastguard Worker    def test_closed_form_step_lr(self):
524*da0073e9SAndroid Build Coastguard Worker        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
525*da0073e9SAndroid Build Coastguard Worker        closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
526*da0073e9SAndroid Build Coastguard Worker        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker    def test_closed_form_linearlr(self):
529*da0073e9SAndroid Build Coastguard Worker        scheduler = LinearLR(
530*da0073e9SAndroid Build Coastguard Worker            self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4
531*da0073e9SAndroid Build Coastguard Worker        )
532*da0073e9SAndroid Build Coastguard Worker        closed_form_scheduler = LinearLR(
533*da0073e9SAndroid Build Coastguard Worker            self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4
534*da0073e9SAndroid Build Coastguard Worker        )
535*da0073e9SAndroid Build Coastguard Worker        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker    def test_closed_form_constantlr(self):
538*da0073e9SAndroid Build Coastguard Worker        scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
539*da0073e9SAndroid Build Coastguard Worker        closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
540*da0073e9SAndroid Build Coastguard Worker        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
541*da0073e9SAndroid Build Coastguard Worker
542*da0073e9SAndroid Build Coastguard Worker    def test_closed_form_multi_step_lr(self):
543*da0073e9SAndroid Build Coastguard Worker        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
544*da0073e9SAndroid Build Coastguard Worker        closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
545*da0073e9SAndroid Build Coastguard Worker        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker    def test_closed_form_exp_lr(self):
548*da0073e9SAndroid Build Coastguard Worker        scheduler = ExponentialLR(self.opt, gamma=0.9)
549*da0073e9SAndroid Build Coastguard Worker        closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9)
550*da0073e9SAndroid Build Coastguard Worker        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker    def test_closed_form_poly_lr(self):
553*da0073e9SAndroid Build Coastguard Worker        scheduler = PolynomialLR(self.opt, power=0.9)
554*da0073e9SAndroid Build Coastguard Worker        closed_form_scheduler = PolynomialLR(self.opt, power=0.9)
555*da0073e9SAndroid Build Coastguard Worker        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker    def test_closed_form_cos_anneal_lr(self):
558*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
559*da0073e9SAndroid Build Coastguard Worker        epochs = 20
560*da0073e9SAndroid Build Coastguard Worker        T_max = 5
561*da0073e9SAndroid Build Coastguard Worker        scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
562*da0073e9SAndroid Build Coastguard Worker        closed_form_scheduler = CosineAnnealingLR(
563*da0073e9SAndroid Build Coastguard Worker            self.opt, T_max=T_max, eta_min=eta_min
564*da0073e9SAndroid Build Coastguard Worker        )
565*da0073e9SAndroid Build Coastguard Worker        self._test_against_closed_form(scheduler, closed_form_scheduler, epochs)
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker    def test_cos_anneal_lr_continue(self):
568*da0073e9SAndroid Build Coastguard Worker        eta_min = 0.1
569*da0073e9SAndroid Build Coastguard Worker        T_max = 5
570*da0073e9SAndroid Build Coastguard Worker        scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
571*da0073e9SAndroid Build Coastguard Worker        self.opt.step()
572*da0073e9SAndroid Build Coastguard Worker        scheduler.step()
573*da0073e9SAndroid Build Coastguard Worker        original_lrs = scheduler._last_lr
574*da0073e9SAndroid Build Coastguard Worker        new_scheduler = CosineAnnealingLR(
575*da0073e9SAndroid Build Coastguard Worker            self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0
576*da0073e9SAndroid Build Coastguard Worker        )
577*da0073e9SAndroid Build Coastguard Worker        new_lrs = new_scheduler._last_lr
578*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5)
579*da0073e9SAndroid Build Coastguard Worker
580*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau1(self):
581*da0073e9SAndroid Build Coastguard Worker        epochs = 10
582*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
583*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
584*da0073e9SAndroid Build Coastguard Worker        targets = [[0.5] * 20]
585*da0073e9SAndroid Build Coastguard Worker        metrics = [10 - i * 0.0167 for i in range(20)]
586*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
587*da0073e9SAndroid Build Coastguard Worker            self.opt,
588*da0073e9SAndroid Build Coastguard Worker            threshold_mode="abs",
589*da0073e9SAndroid Build Coastguard Worker            mode="min",
590*da0073e9SAndroid Build Coastguard Worker            threshold=0.01,
591*da0073e9SAndroid Build Coastguard Worker            patience=5,
592*da0073e9SAndroid Build Coastguard Worker            cooldown=5,
593*da0073e9SAndroid Build Coastguard Worker        )
594*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau2(self):
597*da0073e9SAndroid Build Coastguard Worker        epochs = 22
598*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
599*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
600*da0073e9SAndroid Build Coastguard Worker        targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
601*da0073e9SAndroid Build Coastguard Worker        metrics = [10 - i * 0.0165 for i in range(22)]
602*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
603*da0073e9SAndroid Build Coastguard Worker            self.opt,
604*da0073e9SAndroid Build Coastguard Worker            patience=5,
605*da0073e9SAndroid Build Coastguard Worker            cooldown=0,
606*da0073e9SAndroid Build Coastguard Worker            threshold_mode="abs",
607*da0073e9SAndroid Build Coastguard Worker            mode="min",
608*da0073e9SAndroid Build Coastguard Worker            threshold=0.1,
609*da0073e9SAndroid Build Coastguard Worker        )
610*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau3(self):
613*da0073e9SAndroid Build Coastguard Worker        epochs = 22
614*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
615*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
616*da0073e9SAndroid Build Coastguard Worker        targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
617*da0073e9SAndroid Build Coastguard Worker        metrics = [-0.8] * 2 + [-0.234] * 20
618*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
619*da0073e9SAndroid Build Coastguard Worker            self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs"
620*da0073e9SAndroid Build Coastguard Worker        )
621*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau4(self):
624*da0073e9SAndroid Build Coastguard Worker        epochs = 20
625*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
626*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
627*da0073e9SAndroid Build Coastguard Worker        targets = [[0.5] * 20]
628*da0073e9SAndroid Build Coastguard Worker        metrics = [1.5 * (1.025**i) for i in range(20)]  # 1.025 > 1.1**0.25
629*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
630*da0073e9SAndroid Build Coastguard Worker            self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1
631*da0073e9SAndroid Build Coastguard Worker        )
632*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau5(self):
635*da0073e9SAndroid Build Coastguard Worker        epochs = 20
636*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
637*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
638*da0073e9SAndroid Build Coastguard Worker        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
639*da0073e9SAndroid Build Coastguard Worker        metrics = [1.5 * (1.005**i) for i in range(20)]
640*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
641*da0073e9SAndroid Build Coastguard Worker            self.opt,
642*da0073e9SAndroid Build Coastguard Worker            mode="max",
643*da0073e9SAndroid Build Coastguard Worker            threshold_mode="rel",
644*da0073e9SAndroid Build Coastguard Worker            threshold=0.1,
645*da0073e9SAndroid Build Coastguard Worker            patience=5,
646*da0073e9SAndroid Build Coastguard Worker            cooldown=5,
647*da0073e9SAndroid Build Coastguard Worker        )
648*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau6(self):
651*da0073e9SAndroid Build Coastguard Worker        epochs = 20
652*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
653*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
654*da0073e9SAndroid Build Coastguard Worker        targets = [[0.5] * 20]
655*da0073e9SAndroid Build Coastguard Worker        metrics = [1.5 * (0.85**i) for i in range(20)]
656*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
657*da0073e9SAndroid Build Coastguard Worker            self.opt, mode="min", threshold_mode="rel", threshold=0.1
658*da0073e9SAndroid Build Coastguard Worker        )
659*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau7(self):
662*da0073e9SAndroid Build Coastguard Worker        epochs = 20
663*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
664*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
665*da0073e9SAndroid Build Coastguard Worker        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
666*da0073e9SAndroid Build Coastguard Worker        metrics = [1] * 7 + [0.6] + [0.5] * 12
667*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
668*da0073e9SAndroid Build Coastguard Worker            self.opt,
669*da0073e9SAndroid Build Coastguard Worker            mode="min",
670*da0073e9SAndroid Build Coastguard Worker            threshold_mode="rel",
671*da0073e9SAndroid Build Coastguard Worker            threshold=0.1,
672*da0073e9SAndroid Build Coastguard Worker            patience=5,
673*da0073e9SAndroid Build Coastguard Worker            cooldown=5,
674*da0073e9SAndroid Build Coastguard Worker        )
675*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau8(self):
678*da0073e9SAndroid Build Coastguard Worker        epochs = 20
679*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
680*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
681*da0073e9SAndroid Build Coastguard Worker        targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14]
682*da0073e9SAndroid Build Coastguard Worker        metrics = [1.5 * (1.005**i) for i in range(20)]
683*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
684*da0073e9SAndroid Build Coastguard Worker            self.opt,
685*da0073e9SAndroid Build Coastguard Worker            mode="max",
686*da0073e9SAndroid Build Coastguard Worker            threshold_mode="rel",
687*da0073e9SAndroid Build Coastguard Worker            min_lr=[0.4, 0.3],
688*da0073e9SAndroid Build Coastguard Worker            threshold=0.1,
689*da0073e9SAndroid Build Coastguard Worker            patience=5,
690*da0073e9SAndroid Build Coastguard Worker            cooldown=5,
691*da0073e9SAndroid Build Coastguard Worker        )
692*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau_get_last_lr_before_step(self):
695*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
696*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
697*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(
698*da0073e9SAndroid Build Coastguard Worker            self.opt,
699*da0073e9SAndroid Build Coastguard Worker        )
700*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
701*da0073e9SAndroid Build Coastguard Worker            scheduler.get_last_lr(), [0.5 for param_group in self.opt.param_groups]
702*da0073e9SAndroid Build Coastguard Worker        )
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker    def test_sequentiallr1(self):
705*da0073e9SAndroid Build Coastguard Worker        epochs = 19
706*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
707*da0073e9SAndroid Build Coastguard Worker        targets = [
708*da0073e9SAndroid Build Coastguard Worker            [0.05, 0.04, 0.032]
709*da0073e9SAndroid Build Coastguard Worker            + [0.05 for x in range(4)]
710*da0073e9SAndroid Build Coastguard Worker            + [0.05 * 0.1 for x in range(4)]
711*da0073e9SAndroid Build Coastguard Worker            + [0.05 * 0.01 for x in range(4)]
712*da0073e9SAndroid Build Coastguard Worker            + [0.05 * 0.001 for x in range(4)]
713*da0073e9SAndroid Build Coastguard Worker        ]
714*da0073e9SAndroid Build Coastguard Worker        milestones = [3]
715*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ExponentialLR(self.opt, gamma=0.8)
716*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4)
717*da0073e9SAndroid Build Coastguard Worker        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
718*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker    def test_sequentiallr2(self):
721*da0073e9SAndroid Build Coastguard Worker        epochs = 13
722*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
723*da0073e9SAndroid Build Coastguard Worker        targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9**x for x in range(10)]]
724*da0073e9SAndroid Build Coastguard Worker        milestones = [3]
725*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
726*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
727*da0073e9SAndroid Build Coastguard Worker        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
728*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker    def test_sequentiallr3(self):
731*da0073e9SAndroid Build Coastguard Worker        epochs = 12
732*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 3
733*da0073e9SAndroid Build Coastguard Worker        targets = [
734*da0073e9SAndroid Build Coastguard Worker            [0.005, 0.005, 0.005]
735*da0073e9SAndroid Build Coastguard Worker            + [0.05, 0.04, 0.032]
736*da0073e9SAndroid Build Coastguard Worker            + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
737*da0073e9SAndroid Build Coastguard Worker        ]
738*da0073e9SAndroid Build Coastguard Worker        milestones = [3, 6]
739*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
740*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
741*da0073e9SAndroid Build Coastguard Worker        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
742*da0073e9SAndroid Build Coastguard Worker        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
743*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker    def test_sequentiallr4(self):
746*da0073e9SAndroid Build Coastguard Worker        optimizer = SGD([torch.tensor(0.5)], lr=0.1)
747*da0073e9SAndroid Build Coastguard Worker        prev_lr = optimizer.param_groups[0]["lr"]
748*da0073e9SAndroid Build Coastguard Worker
749*da0073e9SAndroid Build Coastguard Worker        schedulers = [
750*da0073e9SAndroid Build Coastguard Worker            torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1),
751*da0073e9SAndroid Build Coastguard Worker            torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1),
752*da0073e9SAndroid Build Coastguard Worker        ]
753*da0073e9SAndroid Build Coastguard Worker        scheduler = torch.optim.lr_scheduler.SequentialLR(
754*da0073e9SAndroid Build Coastguard Worker            optimizer, schedulers, milestones=[10]
755*da0073e9SAndroid Build Coastguard Worker        )
756*da0073e9SAndroid Build Coastguard Worker
757*da0073e9SAndroid Build Coastguard Worker        new_lr = optimizer.param_groups[0]["lr"]
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker        # Ensure that multiple schedulers does not affect the initial learning rate
760*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(prev_lr, new_lr)
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker    def test_get_last_lr_sequentiallr(self):
763*da0073e9SAndroid Build Coastguard Worker        epochs = 12
764*da0073e9SAndroid Build Coastguard Worker        milestones = [3, 6]
765*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 3
766*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
767*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
768*da0073e9SAndroid Build Coastguard Worker        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
769*da0073e9SAndroid Build Coastguard Worker        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
770*da0073e9SAndroid Build Coastguard Worker        constant_lr_target = [0.005] * 3
771*da0073e9SAndroid Build Coastguard Worker        exponential_lr_target = [0.05, 0.04, 0.032]
772*da0073e9SAndroid Build Coastguard Worker        step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
773*da0073e9SAndroid Build Coastguard Worker        single_targets = constant_lr_target + exponential_lr_target + step_lr_target
774*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * 10 for x in single_targets]]
775*da0073e9SAndroid Build Coastguard Worker        self._test_get_last_lr(scheduler, targets, epochs)
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker    def test_chained_lr2_get_last_lr_before_step(self):
778*da0073e9SAndroid Build Coastguard Worker        schedulers = [
779*da0073e9SAndroid Build Coastguard Worker            LinearLR(self.opt, start_factor=0.4, total_iters=3),
780*da0073e9SAndroid Build Coastguard Worker            MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1),
781*da0073e9SAndroid Build Coastguard Worker        ]
782*da0073e9SAndroid Build Coastguard Worker        scheduler = ChainedScheduler(schedulers)
783*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker    def test_chained_lr1(self):
786*da0073e9SAndroid Build Coastguard Worker        epochs = 10
787*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 1
788*da0073e9SAndroid Build Coastguard Worker        targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3]
789*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
790*da0073e9SAndroid Build Coastguard Worker        scheduler = ChainedScheduler(schedulers)
791*da0073e9SAndroid Build Coastguard Worker        self._test([scheduler], targets, epochs)
792*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
793*da0073e9SAndroid Build Coastguard Worker
794*da0073e9SAndroid Build Coastguard Worker    def test_chained_lr2(self):
795*da0073e9SAndroid Build Coastguard Worker        epochs = 10
796*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 1
797*da0073e9SAndroid Build Coastguard Worker        targets = [[0.02, 0.03, 0.04] + [0.05] * 9]
798*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
799*da0073e9SAndroid Build Coastguard Worker        scheduler = ChainedScheduler(schedulers)
800*da0073e9SAndroid Build Coastguard Worker        self._test([scheduler], targets, epochs)
801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
802*da0073e9SAndroid Build Coastguard Worker
803*da0073e9SAndroid Build Coastguard Worker    def test_chained_lr3(self):
804*da0073e9SAndroid Build Coastguard Worker        epochs = 10
805*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
806*da0073e9SAndroid Build Coastguard Worker        targets = [
807*da0073e9SAndroid Build Coastguard Worker            [0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3
808*da0073e9SAndroid Build Coastguard Worker        ]
809*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
810*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1)
811*da0073e9SAndroid Build Coastguard Worker        scheduler = ChainedScheduler(schedulers)
812*da0073e9SAndroid Build Coastguard Worker        self._test([scheduler], targets, epochs)
813*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
814*da0073e9SAndroid Build Coastguard Worker
815*da0073e9SAndroid Build Coastguard Worker    def test_chained_lr4(self):
816*da0073e9SAndroid Build Coastguard Worker        epochs = 9
817*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 3
818*da0073e9SAndroid Build Coastguard Worker        targets = [
819*da0073e9SAndroid Build Coastguard Worker            [0.05 * 0.2 * 0.9**x for x in range(3)]
820*da0073e9SAndroid Build Coastguard Worker            + [0.05 * 0.2 * 0.9**3 * 0.1]
821*da0073e9SAndroid Build Coastguard Worker            + [0.05 * 0.9**x * 0.1 for x in range(4, 6)]
822*da0073e9SAndroid Build Coastguard Worker            + [0.05 * 0.9**x * 0.01 for x in range(6, 9)]
823*da0073e9SAndroid Build Coastguard Worker        ]
824*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ExponentialLR(self.opt, gamma=0.9)
825*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4)
826*da0073e9SAndroid Build Coastguard Worker        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3)
827*da0073e9SAndroid Build Coastguard Worker        scheduler = ChainedScheduler(schedulers)
828*da0073e9SAndroid Build Coastguard Worker        self._test([scheduler], targets, epochs)
829*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker    def test_chained_lr5(self):
832*da0073e9SAndroid Build Coastguard Worker        def poly_lr(lr: float):
833*da0073e9SAndroid Build Coastguard Worker            return [
834*da0073e9SAndroid Build Coastguard Worker                (lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters)
835*da0073e9SAndroid Build Coastguard Worker            ] + [0.0] * (epochs - total_iters)
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
838*da0073e9SAndroid Build Coastguard Worker        epochs = 10
839*da0073e9SAndroid Build Coastguard Worker        power = 0.9
840*da0073e9SAndroid Build Coastguard Worker        total_iters = 5
841*da0073e9SAndroid Build Coastguard Worker        const_factor = 0.1
842*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * const_factor for x in poly_lr(lr=0.05)]
843*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * const_factor for x in poly_lr(0.5)]]
844*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = PolynomialLR(self.opt, power=power, total_iters=total_iters)
845*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ConstantLR(self.opt, factor=const_factor)
846*da0073e9SAndroid Build Coastguard Worker        scheduler = ChainedScheduler(schedulers)
847*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
849*da0073e9SAndroid Build Coastguard Worker
850*da0073e9SAndroid Build Coastguard Worker    def test_compound_step_and_multistep_lr(self):
851*da0073e9SAndroid Build Coastguard Worker        epochs = 10
852*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
853*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
854*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
855*da0073e9SAndroid Build Coastguard Worker        targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]]
856*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Worker    def test_compound_step_and_exp_lr(self):
859*da0073e9SAndroid Build Coastguard Worker        epochs = 10
860*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
861*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05 * (0.9**x) for x in range(3)]
862*da0073e9SAndroid Build Coastguard Worker        single_targets += [0.005 * (0.9**x) for x in range(3, 6)]
863*da0073e9SAndroid Build Coastguard Worker        single_targets += [0.0005 * (0.9**x) for x in range(6, 9)]
864*da0073e9SAndroid Build Coastguard Worker        single_targets += [0.00005 * (0.9**x) for x in range(9, 12)]
865*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
866*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
867*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
868*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
869*da0073e9SAndroid Build Coastguard Worker
870*da0073e9SAndroid Build Coastguard Worker    def test_compound_exp_and_multistep_lr(self):
871*da0073e9SAndroid Build Coastguard Worker        epochs = 10
872*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
873*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05 * (0.9**x) for x in range(2)]
874*da0073e9SAndroid Build Coastguard Worker        single_targets += [0.005 * (0.9**x) for x in range(2, 5)]
875*da0073e9SAndroid Build Coastguard Worker        single_targets += [0.0005 * (0.9**x) for x in range(5, 9)]
876*da0073e9SAndroid Build Coastguard Worker        single_targets += [0.00005 * (0.9**x) for x in range(9, 11)]
877*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
878*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
879*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
880*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker    def test_compound_exp_and_linearlr(self):
883*da0073e9SAndroid Build Coastguard Worker        epochs = 10
884*da0073e9SAndroid Build Coastguard Worker        iters = 4
885*da0073e9SAndroid Build Coastguard Worker        start_factor = 0.4
886*da0073e9SAndroid Build Coastguard Worker        end_factor = 0.9
887*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
888*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05 * (0.9**x) for x in range(11)]
889*da0073e9SAndroid Build Coastguard Worker        for i in range(iters):
890*da0073e9SAndroid Build Coastguard Worker            single_targets[i] *= start_factor + i / iters * (end_factor - start_factor)
891*da0073e9SAndroid Build Coastguard Worker        for i in range(iters, 11):
892*da0073e9SAndroid Build Coastguard Worker            single_targets[i] *= end_factor
893*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
894*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = LinearLR(
895*da0073e9SAndroid Build Coastguard Worker            self.opt,
896*da0073e9SAndroid Build Coastguard Worker            start_factor=start_factor,
897*da0073e9SAndroid Build Coastguard Worker            end_factor=end_factor,
898*da0073e9SAndroid Build Coastguard Worker            total_iters=iters,
899*da0073e9SAndroid Build Coastguard Worker        )
900*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
901*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Worker    def test_compound_step_and_constantlr(self):
904*da0073e9SAndroid Build Coastguard Worker        epochs = 10
905*da0073e9SAndroid Build Coastguard Worker        iters = 4
906*da0073e9SAndroid Build Coastguard Worker        factor = 0.4
907*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
908*da0073e9SAndroid Build Coastguard Worker        single_targets = (
909*da0073e9SAndroid Build Coastguard Worker            [0.05 * 0.4] * 3
910*da0073e9SAndroid Build Coastguard Worker            + [0.005 * 0.4]
911*da0073e9SAndroid Build Coastguard Worker            + [0.005] * 2
912*da0073e9SAndroid Build Coastguard Worker            + [0.0005] * 3
913*da0073e9SAndroid Build Coastguard Worker            + [0.00005] * 3
914*da0073e9SAndroid Build Coastguard Worker        )
915*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
916*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
917*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4)
918*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker    def test_compound_linearlr_and_multistep_lr(self):
921*da0073e9SAndroid Build Coastguard Worker        epochs = 10
922*da0073e9SAndroid Build Coastguard Worker        iters = 4
923*da0073e9SAndroid Build Coastguard Worker        start_factor = 0.4
924*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
925*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2
926*da0073e9SAndroid Build Coastguard Worker        for i in range(iters):
927*da0073e9SAndroid Build Coastguard Worker            single_targets[i] *= start_factor + i / iters * (1 - start_factor)
928*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
929*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
930*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
931*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker    def test_compound_cosanneal_and_step_lr(self):
934*da0073e9SAndroid Build Coastguard Worker        epochs = 10
935*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
936*da0073e9SAndroid Build Coastguard Worker        single_targets = [
937*da0073e9SAndroid Build Coastguard Worker            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
938*da0073e9SAndroid Build Coastguard Worker            for x in range(epochs)
939*da0073e9SAndroid Build Coastguard Worker        ]
940*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)]
941*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
942*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
943*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
944*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
945*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
946*da0073e9SAndroid Build Coastguard Worker
947*da0073e9SAndroid Build Coastguard Worker    def test_compound_cosanneal_and_multistep_lr(self):
948*da0073e9SAndroid Build Coastguard Worker        epochs = 10
949*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
950*da0073e9SAndroid Build Coastguard Worker        single_targets = [
951*da0073e9SAndroid Build Coastguard Worker            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
952*da0073e9SAndroid Build Coastguard Worker            for x in range(epochs)
953*da0073e9SAndroid Build Coastguard Worker        ]
954*da0073e9SAndroid Build Coastguard Worker        multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001]
955*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
956*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
957*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
958*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
959*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
960*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
961*da0073e9SAndroid Build Coastguard Worker
962*da0073e9SAndroid Build Coastguard Worker    def test_compound_cosanneal_and_linearlr(self):
963*da0073e9SAndroid Build Coastguard Worker        epochs = 10
964*da0073e9SAndroid Build Coastguard Worker        iters = 4
965*da0073e9SAndroid Build Coastguard Worker        start_factor = 0.4
966*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
967*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
968*da0073e9SAndroid Build Coastguard Worker        single_targets = [
969*da0073e9SAndroid Build Coastguard Worker            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
970*da0073e9SAndroid Build Coastguard Worker            for x in range(epochs)
971*da0073e9SAndroid Build Coastguard Worker        ]
972*da0073e9SAndroid Build Coastguard Worker        for i in range(iters):
973*da0073e9SAndroid Build Coastguard Worker            single_targets[i] *= start_factor + i / iters * (1 - start_factor)
974*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
975*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
976*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
977*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker    def test_compound_cosanneal_and_exp_lr(self):
980*da0073e9SAndroid Build Coastguard Worker        epochs = 10
981*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
982*da0073e9SAndroid Build Coastguard Worker        single_targets = [
983*da0073e9SAndroid Build Coastguard Worker            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
984*da0073e9SAndroid Build Coastguard Worker            for x in range(epochs)
985*da0073e9SAndroid Build Coastguard Worker        ]
986*da0073e9SAndroid Build Coastguard Worker        multipliers = [0.1**i for i in range(epochs)]
987*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
988*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets, [x * epochs for x in single_targets]]
989*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
990*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
991*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
992*da0073e9SAndroid Build Coastguard Worker        self._test(schedulers, targets, epochs)
993*da0073e9SAndroid Build Coastguard Worker
994*da0073e9SAndroid Build Coastguard Worker    def test_compound_reduce_lr_on_plateau1(self):
995*da0073e9SAndroid Build Coastguard Worker        epochs = 10
996*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
997*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
998*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.5] * 20
999*da0073e9SAndroid Build Coastguard Worker        multipliers = [0.1 ** (i // 3) for i in range(20)]
1000*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * y for x, y in zip(multipliers, single_targets)]
1001*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets]
1002*da0073e9SAndroid Build Coastguard Worker        targets = targets[1:]  # test runs step before checking lr
1003*da0073e9SAndroid Build Coastguard Worker        metrics = [10 - i * 0.0167 for i in range(20)]
1004*da0073e9SAndroid Build Coastguard Worker        schedulers = [None, None]
1005*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ReduceLROnPlateau(
1006*da0073e9SAndroid Build Coastguard Worker            self.opt,
1007*da0073e9SAndroid Build Coastguard Worker            threshold_mode="abs",
1008*da0073e9SAndroid Build Coastguard Worker            mode="min",
1009*da0073e9SAndroid Build Coastguard Worker            threshold=0.01,
1010*da0073e9SAndroid Build Coastguard Worker            patience=5,
1011*da0073e9SAndroid Build Coastguard Worker            cooldown=5,
1012*da0073e9SAndroid Build Coastguard Worker        )
1013*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
1014*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker    def test_compound_reduce_lr_on_plateau2(self):
1017*da0073e9SAndroid Build Coastguard Worker        epochs = 22
1018*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
1019*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
1020*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
1021*da0073e9SAndroid Build Coastguard Worker        multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10
1022*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
1023*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets]
1024*da0073e9SAndroid Build Coastguard Worker        targets = targets[1:]  # test runs step before checking lr
1025*da0073e9SAndroid Build Coastguard Worker        metrics = [10 - i * 0.0165 for i in range(22)]
1026*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
1027*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ReduceLROnPlateau(
1028*da0073e9SAndroid Build Coastguard Worker            self.opt,
1029*da0073e9SAndroid Build Coastguard Worker            patience=5,
1030*da0073e9SAndroid Build Coastguard Worker            cooldown=0,
1031*da0073e9SAndroid Build Coastguard Worker            threshold_mode="abs",
1032*da0073e9SAndroid Build Coastguard Worker            mode="min",
1033*da0073e9SAndroid Build Coastguard Worker            threshold=0.1,
1034*da0073e9SAndroid Build Coastguard Worker        )
1035*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12])
1036*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1037*da0073e9SAndroid Build Coastguard Worker
1038*da0073e9SAndroid Build Coastguard Worker    def test_compound_reduce_lr_on_plateau3(self):
1039*da0073e9SAndroid Build Coastguard Worker        epochs = 22
1040*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
1041*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
1042*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4
1043*da0073e9SAndroid Build Coastguard Worker        multipliers = [0.1**i for i in range(epochs)]
1044*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * y for x, y in zip(multipliers, single_targets)]
1045*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets]
1046*da0073e9SAndroid Build Coastguard Worker        targets = targets[1:]  # test runs step before checking lr
1047*da0073e9SAndroid Build Coastguard Worker        metrics = [-0.8] * 2 + [-0.234] * 20
1048*da0073e9SAndroid Build Coastguard Worker        schedulers = [None, None]
1049*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ReduceLROnPlateau(
1050*da0073e9SAndroid Build Coastguard Worker            self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs"
1051*da0073e9SAndroid Build Coastguard Worker        )
1052*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
1053*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker    def test_compound_reduce_lr_on_plateau4(self):
1056*da0073e9SAndroid Build Coastguard Worker        epochs = 20
1057*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
1058*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.05
1059*da0073e9SAndroid Build Coastguard Worker        epochs = 10
1060*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
1061*da0073e9SAndroid Build Coastguard Worker        single_targets = [
1062*da0073e9SAndroid Build Coastguard Worker            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
1063*da0073e9SAndroid Build Coastguard Worker            for x in range(epochs)
1064*da0073e9SAndroid Build Coastguard Worker        ]
1065*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets]
1066*da0073e9SAndroid Build Coastguard Worker        targets = targets[1:]  # test runs step before checking lr
1067*da0073e9SAndroid Build Coastguard Worker        metrics = [1.5 * (1.025**i) for i in range(20)]  # 1.025 > 1.1**0.25
1068*da0073e9SAndroid Build Coastguard Worker        schedulers = [None, None]
1069*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ReduceLROnPlateau(
1070*da0073e9SAndroid Build Coastguard Worker            self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1
1071*da0073e9SAndroid Build Coastguard Worker        )
1072*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min)
1073*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Worker    def test_compound_reduce_lr_on_plateau5(self):
1076*da0073e9SAndroid Build Coastguard Worker        iters = 4
1077*da0073e9SAndroid Build Coastguard Worker        start_factor = 0.4
1078*da0073e9SAndroid Build Coastguard Worker        epochs = 22
1079*da0073e9SAndroid Build Coastguard Worker        for param_group in self.opt.param_groups:
1080*da0073e9SAndroid Build Coastguard Worker            param_group["lr"] = 0.5
1081*da0073e9SAndroid Build Coastguard Worker        single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
1082*da0073e9SAndroid Build Coastguard Worker        multipliers = [1] * 22
1083*da0073e9SAndroid Build Coastguard Worker        for i in range(iters):
1084*da0073e9SAndroid Build Coastguard Worker            multipliers[i] *= start_factor + i / iters * (1 - start_factor)
1085*da0073e9SAndroid Build Coastguard Worker        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
1086*da0073e9SAndroid Build Coastguard Worker        targets = [single_targets]
1087*da0073e9SAndroid Build Coastguard Worker        targets = targets[1:]  # test runs step before checking lr
1088*da0073e9SAndroid Build Coastguard Worker        metrics = [10 - i * 0.0165 for i in range(22)]
1089*da0073e9SAndroid Build Coastguard Worker        schedulers = [None] * 2
1090*da0073e9SAndroid Build Coastguard Worker        schedulers[0] = ReduceLROnPlateau(
1091*da0073e9SAndroid Build Coastguard Worker            self.opt,
1092*da0073e9SAndroid Build Coastguard Worker            patience=5,
1093*da0073e9SAndroid Build Coastguard Worker            cooldown=0,
1094*da0073e9SAndroid Build Coastguard Worker            threshold_mode="abs",
1095*da0073e9SAndroid Build Coastguard Worker            mode="min",
1096*da0073e9SAndroid Build Coastguard Worker            threshold=0.1,
1097*da0073e9SAndroid Build Coastguard Worker        )
1098*da0073e9SAndroid Build Coastguard Worker        schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
1099*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_invalid_mode(self):
1102*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1103*da0073e9SAndroid Build Coastguard Worker            scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS")
1104*da0073e9SAndroid Build Coastguard Worker
1105*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_triangular_mode_one_lr(self):
1106*da0073e9SAndroid Build Coastguard Worker        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
1107*da0073e9SAndroid Build Coastguard Worker        momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
1108*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1109*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1110*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1111*da0073e9SAndroid Build Coastguard Worker            self.opt,
1112*da0073e9SAndroid Build Coastguard Worker            base_lr=1,
1113*da0073e9SAndroid Build Coastguard Worker            max_lr=5,
1114*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1115*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1116*da0073e9SAndroid Build Coastguard Worker            base_momentum=1,
1117*da0073e9SAndroid Build Coastguard Worker            max_momentum=5,
1118*da0073e9SAndroid Build Coastguard Worker            mode="triangular",
1119*da0073e9SAndroid Build Coastguard Worker        )
1120*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1121*da0073e9SAndroid Build Coastguard Worker
1122*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_triangular_mode_one_lr_no_momentum(self):
1123*da0073e9SAndroid Build Coastguard Worker        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
1124*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1125*da0073e9SAndroid Build Coastguard Worker        momentum_target = [self.opt.defaults["momentum"]] * len(lr_target)
1126*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1127*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1128*da0073e9SAndroid Build Coastguard Worker            self.opt,
1129*da0073e9SAndroid Build Coastguard Worker            base_lr=1,
1130*da0073e9SAndroid Build Coastguard Worker            max_lr=5,
1131*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1132*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=False,
1133*da0073e9SAndroid Build Coastguard Worker            mode="triangular",
1134*da0073e9SAndroid Build Coastguard Worker        )
1135*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1136*da0073e9SAndroid Build Coastguard Worker
1137*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_triangular2_mode_one_lr(self):
1138*da0073e9SAndroid Build Coastguard Worker        lr_target = [
1139*da0073e9SAndroid Build Coastguard Worker            1,
1140*da0073e9SAndroid Build Coastguard Worker            2,
1141*da0073e9SAndroid Build Coastguard Worker            3,
1142*da0073e9SAndroid Build Coastguard Worker            4,
1143*da0073e9SAndroid Build Coastguard Worker            5,
1144*da0073e9SAndroid Build Coastguard Worker            4,
1145*da0073e9SAndroid Build Coastguard Worker            3,
1146*da0073e9SAndroid Build Coastguard Worker            2,
1147*da0073e9SAndroid Build Coastguard Worker            1,
1148*da0073e9SAndroid Build Coastguard Worker            1.5,
1149*da0073e9SAndroid Build Coastguard Worker            2.0,
1150*da0073e9SAndroid Build Coastguard Worker            2.5,
1151*da0073e9SAndroid Build Coastguard Worker            3.0,
1152*da0073e9SAndroid Build Coastguard Worker            2.5,
1153*da0073e9SAndroid Build Coastguard Worker            2.0,
1154*da0073e9SAndroid Build Coastguard Worker            1.5,
1155*da0073e9SAndroid Build Coastguard Worker            1,
1156*da0073e9SAndroid Build Coastguard Worker            1.25,
1157*da0073e9SAndroid Build Coastguard Worker            1.50,
1158*da0073e9SAndroid Build Coastguard Worker            1.75,
1159*da0073e9SAndroid Build Coastguard Worker            2.00,
1160*da0073e9SAndroid Build Coastguard Worker            1.75,
1161*da0073e9SAndroid Build Coastguard Worker        ]
1162*da0073e9SAndroid Build Coastguard Worker        momentum_target = [
1163*da0073e9SAndroid Build Coastguard Worker            5.0,
1164*da0073e9SAndroid Build Coastguard Worker            4.0,
1165*da0073e9SAndroid Build Coastguard Worker            3.0,
1166*da0073e9SAndroid Build Coastguard Worker            2.0,
1167*da0073e9SAndroid Build Coastguard Worker            1.0,
1168*da0073e9SAndroid Build Coastguard Worker            2.0,
1169*da0073e9SAndroid Build Coastguard Worker            3.0,
1170*da0073e9SAndroid Build Coastguard Worker            4.0,
1171*da0073e9SAndroid Build Coastguard Worker            5.0,
1172*da0073e9SAndroid Build Coastguard Worker            4.5,
1173*da0073e9SAndroid Build Coastguard Worker            4.0,
1174*da0073e9SAndroid Build Coastguard Worker            3.5,
1175*da0073e9SAndroid Build Coastguard Worker            3.0,
1176*da0073e9SAndroid Build Coastguard Worker            3.5,
1177*da0073e9SAndroid Build Coastguard Worker            4.0,
1178*da0073e9SAndroid Build Coastguard Worker            4.5,
1179*da0073e9SAndroid Build Coastguard Worker            5.0,
1180*da0073e9SAndroid Build Coastguard Worker            4.75,
1181*da0073e9SAndroid Build Coastguard Worker            4.5,
1182*da0073e9SAndroid Build Coastguard Worker            4.25,
1183*da0073e9SAndroid Build Coastguard Worker            4.0,
1184*da0073e9SAndroid Build Coastguard Worker            4.25,
1185*da0073e9SAndroid Build Coastguard Worker        ]
1186*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1187*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1188*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1189*da0073e9SAndroid Build Coastguard Worker            self.opt,
1190*da0073e9SAndroid Build Coastguard Worker            base_lr=1,
1191*da0073e9SAndroid Build Coastguard Worker            max_lr=5,
1192*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1193*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1194*da0073e9SAndroid Build Coastguard Worker            base_momentum=1,
1195*da0073e9SAndroid Build Coastguard Worker            max_momentum=5,
1196*da0073e9SAndroid Build Coastguard Worker            mode="triangular2",
1197*da0073e9SAndroid Build Coastguard Worker        )
1198*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_exp_range_mode_one_lr(self):
1201*da0073e9SAndroid Build Coastguard Worker        base_lr, max_lr = 1, 5
1202*da0073e9SAndroid Build Coastguard Worker        diff_lr = max_lr - base_lr
1203*da0073e9SAndroid Build Coastguard Worker        gamma = 0.9
1204*da0073e9SAndroid Build Coastguard Worker        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
1205*da0073e9SAndroid Build Coastguard Worker        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
1206*da0073e9SAndroid Build Coastguard Worker        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
1207*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1208*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1209*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1210*da0073e9SAndroid Build Coastguard Worker            self.opt,
1211*da0073e9SAndroid Build Coastguard Worker            base_lr=base_lr,
1212*da0073e9SAndroid Build Coastguard Worker            max_lr=max_lr,
1213*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1214*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1215*da0073e9SAndroid Build Coastguard Worker            base_momentum=base_lr,
1216*da0073e9SAndroid Build Coastguard Worker            max_momentum=max_lr,
1217*da0073e9SAndroid Build Coastguard Worker            mode="exp_range",
1218*da0073e9SAndroid Build Coastguard Worker            gamma=gamma,
1219*da0073e9SAndroid Build Coastguard Worker        )
1220*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1221*da0073e9SAndroid Build Coastguard Worker
1222*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_triangular_mode(self):
1223*da0073e9SAndroid Build Coastguard Worker        lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
1224*da0073e9SAndroid Build Coastguard Worker        lr_target_2 = [x + 1 for x in lr_target_1]
1225*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target_1, lr_target_2]
1226*da0073e9SAndroid Build Coastguard Worker        momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
1227*da0073e9SAndroid Build Coastguard Worker        momentum_target_2 = [x + 1 for x in momentum_target_1]
1228*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target_1, momentum_target_2]
1229*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1230*da0073e9SAndroid Build Coastguard Worker            self.opt,
1231*da0073e9SAndroid Build Coastguard Worker            base_lr=[1, 2],
1232*da0073e9SAndroid Build Coastguard Worker            max_lr=[5, 6],
1233*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1234*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1235*da0073e9SAndroid Build Coastguard Worker            base_momentum=[1, 2],
1236*da0073e9SAndroid Build Coastguard Worker            max_momentum=[5, 6],
1237*da0073e9SAndroid Build Coastguard Worker            mode="triangular",
1238*da0073e9SAndroid Build Coastguard Worker        )
1239*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
1240*da0073e9SAndroid Build Coastguard Worker
1241*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_triangular2_mode(self):
1242*da0073e9SAndroid Build Coastguard Worker        lr_target_1 = [
1243*da0073e9SAndroid Build Coastguard Worker            1,
1244*da0073e9SAndroid Build Coastguard Worker            2,
1245*da0073e9SAndroid Build Coastguard Worker            3,
1246*da0073e9SAndroid Build Coastguard Worker            4,
1247*da0073e9SAndroid Build Coastguard Worker            5,
1248*da0073e9SAndroid Build Coastguard Worker            4,
1249*da0073e9SAndroid Build Coastguard Worker            3,
1250*da0073e9SAndroid Build Coastguard Worker            2,
1251*da0073e9SAndroid Build Coastguard Worker            1,
1252*da0073e9SAndroid Build Coastguard Worker            1.5,
1253*da0073e9SAndroid Build Coastguard Worker            2.0,
1254*da0073e9SAndroid Build Coastguard Worker            2.5,
1255*da0073e9SAndroid Build Coastguard Worker            3.0,
1256*da0073e9SAndroid Build Coastguard Worker            2.5,
1257*da0073e9SAndroid Build Coastguard Worker            2.0,
1258*da0073e9SAndroid Build Coastguard Worker            1.5,
1259*da0073e9SAndroid Build Coastguard Worker            1,
1260*da0073e9SAndroid Build Coastguard Worker            1.25,
1261*da0073e9SAndroid Build Coastguard Worker            1.50,
1262*da0073e9SAndroid Build Coastguard Worker            1.75,
1263*da0073e9SAndroid Build Coastguard Worker            2.00,
1264*da0073e9SAndroid Build Coastguard Worker            1.75,
1265*da0073e9SAndroid Build Coastguard Worker        ]
1266*da0073e9SAndroid Build Coastguard Worker        lr_target_2 = [x + 2 for x in lr_target_1]
1267*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target_1, lr_target_2]
1268*da0073e9SAndroid Build Coastguard Worker        momentum_target_1 = [
1269*da0073e9SAndroid Build Coastguard Worker            5.0,
1270*da0073e9SAndroid Build Coastguard Worker            4.0,
1271*da0073e9SAndroid Build Coastguard Worker            3.0,
1272*da0073e9SAndroid Build Coastguard Worker            2.0,
1273*da0073e9SAndroid Build Coastguard Worker            1.0,
1274*da0073e9SAndroid Build Coastguard Worker            2.0,
1275*da0073e9SAndroid Build Coastguard Worker            3.0,
1276*da0073e9SAndroid Build Coastguard Worker            4.0,
1277*da0073e9SAndroid Build Coastguard Worker            5.0,
1278*da0073e9SAndroid Build Coastguard Worker            4.5,
1279*da0073e9SAndroid Build Coastguard Worker            4.0,
1280*da0073e9SAndroid Build Coastguard Worker            3.5,
1281*da0073e9SAndroid Build Coastguard Worker            3.0,
1282*da0073e9SAndroid Build Coastguard Worker            3.5,
1283*da0073e9SAndroid Build Coastguard Worker            4.0,
1284*da0073e9SAndroid Build Coastguard Worker            4.5,
1285*da0073e9SAndroid Build Coastguard Worker            5.0,
1286*da0073e9SAndroid Build Coastguard Worker            4.75,
1287*da0073e9SAndroid Build Coastguard Worker            4.5,
1288*da0073e9SAndroid Build Coastguard Worker            4.25,
1289*da0073e9SAndroid Build Coastguard Worker            4.0,
1290*da0073e9SAndroid Build Coastguard Worker            4.25,
1291*da0073e9SAndroid Build Coastguard Worker        ]
1292*da0073e9SAndroid Build Coastguard Worker        momentum_target_2 = [x + 2 for x in momentum_target_1]
1293*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target_1, momentum_target_2]
1294*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1295*da0073e9SAndroid Build Coastguard Worker            self.opt,
1296*da0073e9SAndroid Build Coastguard Worker            base_lr=[1, 3],
1297*da0073e9SAndroid Build Coastguard Worker            max_lr=[5, 7],
1298*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1299*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1300*da0073e9SAndroid Build Coastguard Worker            base_momentum=[1, 3],
1301*da0073e9SAndroid Build Coastguard Worker            max_momentum=[5, 7],
1302*da0073e9SAndroid Build Coastguard Worker            mode="triangular2",
1303*da0073e9SAndroid Build Coastguard Worker        )
1304*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_exp_range_mode(self):
1307*da0073e9SAndroid Build Coastguard Worker        base_lr_1, max_lr_1 = 1, 5
1308*da0073e9SAndroid Build Coastguard Worker        base_lr_2, max_lr_2 = 5, 12
1309*da0073e9SAndroid Build Coastguard Worker
1310*da0073e9SAndroid Build Coastguard Worker        diff_lr_1 = max_lr_1 - base_lr_1
1311*da0073e9SAndroid Build Coastguard Worker        diff_lr_2 = max_lr_2 - base_lr_2
1312*da0073e9SAndroid Build Coastguard Worker
1313*da0073e9SAndroid Build Coastguard Worker        gamma = 0.9
1314*da0073e9SAndroid Build Coastguard Worker        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
1315*da0073e9SAndroid Build Coastguard Worker        lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)]
1316*da0073e9SAndroid Build Coastguard Worker        lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)]
1317*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target_1, lr_target_2]
1318*da0073e9SAndroid Build Coastguard Worker        momentum_target_1 = [
1319*da0073e9SAndroid Build Coastguard Worker            max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)
1320*da0073e9SAndroid Build Coastguard Worker        ]
1321*da0073e9SAndroid Build Coastguard Worker        momentum_target_2 = [
1322*da0073e9SAndroid Build Coastguard Worker            max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)
1323*da0073e9SAndroid Build Coastguard Worker        ]
1324*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target_1, momentum_target_2]
1325*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1326*da0073e9SAndroid Build Coastguard Worker            self.opt,
1327*da0073e9SAndroid Build Coastguard Worker            base_lr=[base_lr_1, base_lr_2],
1328*da0073e9SAndroid Build Coastguard Worker            max_lr=[max_lr_1, max_lr_2],
1329*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1330*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1331*da0073e9SAndroid Build Coastguard Worker            base_momentum=[base_lr_1, base_lr_2],
1332*da0073e9SAndroid Build Coastguard Worker            max_momentum=[max_lr_1, max_lr_2],
1333*da0073e9SAndroid Build Coastguard Worker            mode="exp_range",
1334*da0073e9SAndroid Build Coastguard Worker            gamma=gamma,
1335*da0073e9SAndroid Build Coastguard Worker        )
1336*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
1337*da0073e9SAndroid Build Coastguard Worker
1338*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_triangular_mode_step_size_up_down(self):
1339*da0073e9SAndroid Build Coastguard Worker        lr_target = [
1340*da0073e9SAndroid Build Coastguard Worker            1.0,
1341*da0073e9SAndroid Build Coastguard Worker            2.0,
1342*da0073e9SAndroid Build Coastguard Worker            3.0,
1343*da0073e9SAndroid Build Coastguard Worker            4.0,
1344*da0073e9SAndroid Build Coastguard Worker            5.0,
1345*da0073e9SAndroid Build Coastguard Worker            13.0 / 3,
1346*da0073e9SAndroid Build Coastguard Worker            11.0 / 3,
1347*da0073e9SAndroid Build Coastguard Worker            9.0 / 3,
1348*da0073e9SAndroid Build Coastguard Worker            7.0 / 3,
1349*da0073e9SAndroid Build Coastguard Worker            5.0 / 3,
1350*da0073e9SAndroid Build Coastguard Worker            1.0,
1351*da0073e9SAndroid Build Coastguard Worker        ]
1352*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1353*da0073e9SAndroid Build Coastguard Worker        momentum_target = [
1354*da0073e9SAndroid Build Coastguard Worker            5.0,
1355*da0073e9SAndroid Build Coastguard Worker            4.0,
1356*da0073e9SAndroid Build Coastguard Worker            3.0,
1357*da0073e9SAndroid Build Coastguard Worker            2.0,
1358*da0073e9SAndroid Build Coastguard Worker            1.0,
1359*da0073e9SAndroid Build Coastguard Worker            5.0 / 3,
1360*da0073e9SAndroid Build Coastguard Worker            7.0 / 3,
1361*da0073e9SAndroid Build Coastguard Worker            3.0,
1362*da0073e9SAndroid Build Coastguard Worker            11.0 / 3,
1363*da0073e9SAndroid Build Coastguard Worker            13.0 / 3,
1364*da0073e9SAndroid Build Coastguard Worker            5.0,
1365*da0073e9SAndroid Build Coastguard Worker        ]
1366*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1367*da0073e9SAndroid Build Coastguard Worker
1368*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1369*da0073e9SAndroid Build Coastguard Worker            self.opt,
1370*da0073e9SAndroid Build Coastguard Worker            base_lr=1,
1371*da0073e9SAndroid Build Coastguard Worker            max_lr=5,
1372*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1373*da0073e9SAndroid Build Coastguard Worker            step_size_down=6,
1374*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1375*da0073e9SAndroid Build Coastguard Worker            base_momentum=1,
1376*da0073e9SAndroid Build Coastguard Worker            max_momentum=5,
1377*da0073e9SAndroid Build Coastguard Worker            mode="triangular",
1378*da0073e9SAndroid Build Coastguard Worker        )
1379*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1380*da0073e9SAndroid Build Coastguard Worker
1381*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_triangular2_mode_step_size_up_down(self):
1382*da0073e9SAndroid Build Coastguard Worker        lr_base_target = [
1383*da0073e9SAndroid Build Coastguard Worker            1.0,
1384*da0073e9SAndroid Build Coastguard Worker            3.0,
1385*da0073e9SAndroid Build Coastguard Worker            5.0,
1386*da0073e9SAndroid Build Coastguard Worker            13.0 / 3,
1387*da0073e9SAndroid Build Coastguard Worker            11.0 / 3,
1388*da0073e9SAndroid Build Coastguard Worker            9.0 / 3,
1389*da0073e9SAndroid Build Coastguard Worker            7.0 / 3,
1390*da0073e9SAndroid Build Coastguard Worker            5.0 / 3,
1391*da0073e9SAndroid Build Coastguard Worker            1.0,
1392*da0073e9SAndroid Build Coastguard Worker            2.0,
1393*da0073e9SAndroid Build Coastguard Worker            3.0,
1394*da0073e9SAndroid Build Coastguard Worker            8.0 / 3,
1395*da0073e9SAndroid Build Coastguard Worker            7.0 / 3,
1396*da0073e9SAndroid Build Coastguard Worker            6.0 / 3,
1397*da0073e9SAndroid Build Coastguard Worker            5.0 / 3,
1398*da0073e9SAndroid Build Coastguard Worker            4.0 / 3,
1399*da0073e9SAndroid Build Coastguard Worker            1.0,
1400*da0073e9SAndroid Build Coastguard Worker            3.0 / 2,
1401*da0073e9SAndroid Build Coastguard Worker            2.0,
1402*da0073e9SAndroid Build Coastguard Worker            11.0 / 6,
1403*da0073e9SAndroid Build Coastguard Worker            10.0 / 6,
1404*da0073e9SAndroid Build Coastguard Worker            9.0 / 6,
1405*da0073e9SAndroid Build Coastguard Worker            8.0 / 6,
1406*da0073e9SAndroid Build Coastguard Worker            7.0 / 6,
1407*da0073e9SAndroid Build Coastguard Worker        ]
1408*da0073e9SAndroid Build Coastguard Worker        momentum_base_target = [
1409*da0073e9SAndroid Build Coastguard Worker            5.0,
1410*da0073e9SAndroid Build Coastguard Worker            3.0,
1411*da0073e9SAndroid Build Coastguard Worker            1.0,
1412*da0073e9SAndroid Build Coastguard Worker            5.0 / 3,
1413*da0073e9SAndroid Build Coastguard Worker            7.0 / 3,
1414*da0073e9SAndroid Build Coastguard Worker            3.0,
1415*da0073e9SAndroid Build Coastguard Worker            11.0 / 3,
1416*da0073e9SAndroid Build Coastguard Worker            13.0 / 3,
1417*da0073e9SAndroid Build Coastguard Worker            5.0,
1418*da0073e9SAndroid Build Coastguard Worker            4.0,
1419*da0073e9SAndroid Build Coastguard Worker            3.0,
1420*da0073e9SAndroid Build Coastguard Worker            10.0 / 3,
1421*da0073e9SAndroid Build Coastguard Worker            11.0 / 3,
1422*da0073e9SAndroid Build Coastguard Worker            4.0,
1423*da0073e9SAndroid Build Coastguard Worker            13.0 / 3,
1424*da0073e9SAndroid Build Coastguard Worker            14.0 / 3,
1425*da0073e9SAndroid Build Coastguard Worker            5.0,
1426*da0073e9SAndroid Build Coastguard Worker            4.5,
1427*da0073e9SAndroid Build Coastguard Worker            4.0,
1428*da0073e9SAndroid Build Coastguard Worker            25.0 / 6,
1429*da0073e9SAndroid Build Coastguard Worker            13.0 / 3,
1430*da0073e9SAndroid Build Coastguard Worker            4.5,
1431*da0073e9SAndroid Build Coastguard Worker            14.0 / 3,
1432*da0073e9SAndroid Build Coastguard Worker            29.0 / 6,
1433*da0073e9SAndroid Build Coastguard Worker        ]
1434*da0073e9SAndroid Build Coastguard Worker        deltas = [2 * i for i in range(0, 2)]
1435*da0073e9SAndroid Build Coastguard Worker        base_lrs = [1 + delta for delta in deltas]
1436*da0073e9SAndroid Build Coastguard Worker        max_lrs = [5 + delta for delta in deltas]
1437*da0073e9SAndroid Build Coastguard Worker        lr_targets = [[x + delta for x in lr_base_target] for delta in deltas]
1438*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [
1439*da0073e9SAndroid Build Coastguard Worker            [x + delta for x in momentum_base_target] for delta in deltas
1440*da0073e9SAndroid Build Coastguard Worker        ]
1441*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1442*da0073e9SAndroid Build Coastguard Worker            self.opt,
1443*da0073e9SAndroid Build Coastguard Worker            base_lr=base_lrs,
1444*da0073e9SAndroid Build Coastguard Worker            max_lr=max_lrs,
1445*da0073e9SAndroid Build Coastguard Worker            step_size_up=2,
1446*da0073e9SAndroid Build Coastguard Worker            step_size_down=6,
1447*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1448*da0073e9SAndroid Build Coastguard Worker            base_momentum=base_lrs,
1449*da0073e9SAndroid Build Coastguard Worker            max_momentum=max_lrs,
1450*da0073e9SAndroid Build Coastguard Worker            mode="triangular2",
1451*da0073e9SAndroid Build Coastguard Worker        )
1452*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(
1453*da0073e9SAndroid Build Coastguard Worker            scheduler, lr_targets, momentum_targets, len(lr_base_target)
1454*da0073e9SAndroid Build Coastguard Worker        )
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_exp_range_mode_step_size_up_down(self):
1457*da0073e9SAndroid Build Coastguard Worker        base_lr, max_lr = 1, 5
1458*da0073e9SAndroid Build Coastguard Worker        diff_lr = max_lr - base_lr
1459*da0073e9SAndroid Build Coastguard Worker        gamma = 0.9
1460*da0073e9SAndroid Build Coastguard Worker        xs = [
1461*da0073e9SAndroid Build Coastguard Worker            0.0,
1462*da0073e9SAndroid Build Coastguard Worker            0.5,
1463*da0073e9SAndroid Build Coastguard Worker            1.0,
1464*da0073e9SAndroid Build Coastguard Worker            5.0 / 6,
1465*da0073e9SAndroid Build Coastguard Worker            4.0 / 6,
1466*da0073e9SAndroid Build Coastguard Worker            3.0 / 6,
1467*da0073e9SAndroid Build Coastguard Worker            2.0 / 6,
1468*da0073e9SAndroid Build Coastguard Worker            1.0 / 6,
1469*da0073e9SAndroid Build Coastguard Worker            0.0,
1470*da0073e9SAndroid Build Coastguard Worker            0.5,
1471*da0073e9SAndroid Build Coastguard Worker            1.0,
1472*da0073e9SAndroid Build Coastguard Worker            5.0 / 6,
1473*da0073e9SAndroid Build Coastguard Worker            4.0 / 6,
1474*da0073e9SAndroid Build Coastguard Worker        ]
1475*da0073e9SAndroid Build Coastguard Worker        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
1476*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1477*da0073e9SAndroid Build Coastguard Worker        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
1478*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1479*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1480*da0073e9SAndroid Build Coastguard Worker            self.opt,
1481*da0073e9SAndroid Build Coastguard Worker            base_lr=base_lr,
1482*da0073e9SAndroid Build Coastguard Worker            max_lr=max_lr,
1483*da0073e9SAndroid Build Coastguard Worker            step_size_up=2,
1484*da0073e9SAndroid Build Coastguard Worker            step_size_down=6,
1485*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=True,
1486*da0073e9SAndroid Build Coastguard Worker            base_momentum=base_lr,
1487*da0073e9SAndroid Build Coastguard Worker            max_momentum=max_lr,
1488*da0073e9SAndroid Build Coastguard Worker            mode="exp_range",
1489*da0073e9SAndroid Build Coastguard Worker            gamma=gamma,
1490*da0073e9SAndroid Build Coastguard Worker        )
1491*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1492*da0073e9SAndroid Build Coastguard Worker
1493*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_with_momentumless_optimizer(self):
1494*da0073e9SAndroid Build Coastguard Worker        # Note [Temporarily set optimizer to Adam]
1495*da0073e9SAndroid Build Coastguard Worker        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1496*da0073e9SAndroid Build Coastguard Worker        # The TestLRScheduler object carries around an SGD optimizer to avoid having to
1497*da0073e9SAndroid Build Coastguard Worker        # instantiate one for every test. This gets in the way for our very specific case
1498*da0073e9SAndroid Build Coastguard Worker        # in which we need to use Adam (or really any optimizer that doesn't use momentum)
1499*da0073e9SAndroid Build Coastguard Worker        # in order to test that the momentum bug in CyclicLR is fixed (the bug is described
1500*da0073e9SAndroid Build Coastguard Worker        # in more detail in https://github.com/pytorch/pytorch/issues/19003 ).
1501*da0073e9SAndroid Build Coastguard Worker        old_opt = self.opt
1502*da0073e9SAndroid Build Coastguard Worker        self.opt = Adam(
1503*da0073e9SAndroid Build Coastguard Worker            [
1504*da0073e9SAndroid Build Coastguard Worker                {"params": self.net.conv1.parameters()},
1505*da0073e9SAndroid Build Coastguard Worker                {"params": self.net.conv2.parameters(), "lr": 0.5},
1506*da0073e9SAndroid Build Coastguard Worker            ],
1507*da0073e9SAndroid Build Coastguard Worker            lr=0.05,
1508*da0073e9SAndroid Build Coastguard Worker        )
1509*da0073e9SAndroid Build Coastguard Worker
1510*da0073e9SAndroid Build Coastguard Worker        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
1511*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1512*da0073e9SAndroid Build Coastguard Worker        momentum_target = [None] * len(lr_target)
1513*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1514*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1515*da0073e9SAndroid Build Coastguard Worker            self.opt,
1516*da0073e9SAndroid Build Coastguard Worker            base_lr=1,
1517*da0073e9SAndroid Build Coastguard Worker            max_lr=5,
1518*da0073e9SAndroid Build Coastguard Worker            step_size_up=4,
1519*da0073e9SAndroid Build Coastguard Worker            cycle_momentum=False,
1520*da0073e9SAndroid Build Coastguard Worker            mode="triangular",
1521*da0073e9SAndroid Build Coastguard Worker        )
1522*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker        self.opt = old_opt  # set optimizer back to SGD
1525*da0073e9SAndroid Build Coastguard Worker
1526*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
1527*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1528*da0073e9SAndroid Build Coastguard Worker            rprop_opt = Rprop(self.net.parameters())
1529*da0073e9SAndroid Build Coastguard Worker            scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True)
1530*da0073e9SAndroid Build Coastguard Worker
1531*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self):
1532*da0073e9SAndroid Build Coastguard Worker        adam_opt = Adam(self.net.parameters())
1533*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
1534*da0073e9SAndroid Build Coastguard Worker
1535*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_removed_after_out_of_scope(self):
1536*da0073e9SAndroid Build Coastguard Worker        import gc
1537*da0073e9SAndroid Build Coastguard Worker        import weakref
1538*da0073e9SAndroid Build Coastguard Worker
1539*da0073e9SAndroid Build Coastguard Worker        gc.disable()
1540*da0073e9SAndroid Build Coastguard Worker
1541*da0073e9SAndroid Build Coastguard Worker        def test():
1542*da0073e9SAndroid Build Coastguard Worker            adam_opt = Adam(self.net.parameters())
1543*da0073e9SAndroid Build Coastguard Worker            scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
1544*da0073e9SAndroid Build Coastguard Worker            return weakref.ref(scheduler)
1545*da0073e9SAndroid Build Coastguard Worker
1546*da0073e9SAndroid Build Coastguard Worker        ref = test()
1547*da0073e9SAndroid Build Coastguard Worker        assert ref() is None
1548*da0073e9SAndroid Build Coastguard Worker        gc.enable()
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_state_dict_picklable(self):
1551*da0073e9SAndroid Build Coastguard Worker        adam_opt = Adam(self.net.parameters())
1552*da0073e9SAndroid Build Coastguard Worker
1553*da0073e9SAndroid Build Coastguard Worker        # Case 1: Built-in mode
1554*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
1555*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(scheduler._scale_fn_ref, types.FunctionType)
1556*da0073e9SAndroid Build Coastguard Worker        state = scheduler.state_dict()
1557*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn("_scale_fn_ref", state)
1558*da0073e9SAndroid Build Coastguard Worker        self.assertIs(state["_scale_fn_custom"], None)
1559*da0073e9SAndroid Build Coastguard Worker        pickle.dumps(state)
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker        # Case 2: Custom `scale_fn`, a function object
1562*da0073e9SAndroid Build Coastguard Worker        def scale_fn(_):
1563*da0073e9SAndroid Build Coastguard Worker            return 0.5
1564*da0073e9SAndroid Build Coastguard Worker
1565*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1566*da0073e9SAndroid Build Coastguard Worker            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
1567*da0073e9SAndroid Build Coastguard Worker        )
1568*da0073e9SAndroid Build Coastguard Worker        state = scheduler.state_dict()
1569*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn("_scale_fn_ref", state)
1570*da0073e9SAndroid Build Coastguard Worker        self.assertIs(state["_scale_fn_custom"], None)
1571*da0073e9SAndroid Build Coastguard Worker        pickle.dumps(state)
1572*da0073e9SAndroid Build Coastguard Worker
1573*da0073e9SAndroid Build Coastguard Worker        # Case 3: Custom `scale_fn`, a callable class
1574*da0073e9SAndroid Build Coastguard Worker        class ScaleFn:
1575*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1576*da0073e9SAndroid Build Coastguard Worker                self.x = 0.5
1577*da0073e9SAndroid Build Coastguard Worker
1578*da0073e9SAndroid Build Coastguard Worker            def __call__(self, _):
1579*da0073e9SAndroid Build Coastguard Worker                return self.x
1580*da0073e9SAndroid Build Coastguard Worker
1581*da0073e9SAndroid Build Coastguard Worker        scale_fn = ScaleFn()
1582*da0073e9SAndroid Build Coastguard Worker
1583*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1584*da0073e9SAndroid Build Coastguard Worker            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
1585*da0073e9SAndroid Build Coastguard Worker        )
1586*da0073e9SAndroid Build Coastguard Worker        state = scheduler.state_dict()
1587*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn("_scale_fn_ref", state)
1588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(state["_scale_fn_custom"], scale_fn.__dict__)
1589*da0073e9SAndroid Build Coastguard Worker        pickle.dumps(state)
1590*da0073e9SAndroid Build Coastguard Worker
1591*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_scale_fn_restored_from_state_dict(self):
1592*da0073e9SAndroid Build Coastguard Worker        adam_opt = Adam(self.net.parameters())
1593*da0073e9SAndroid Build Coastguard Worker
1594*da0073e9SAndroid Build Coastguard Worker        # Case 1: Built-in mode
1595*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1596*da0073e9SAndroid Build Coastguard Worker            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, mode="triangular2"
1597*da0073e9SAndroid Build Coastguard Worker        )
1598*da0073e9SAndroid Build Coastguard Worker        restored_scheduler = CyclicLR(
1599*da0073e9SAndroid Build Coastguard Worker            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False
1600*da0073e9SAndroid Build Coastguard Worker        )
1601*da0073e9SAndroid Build Coastguard Worker        restored_scheduler.load_state_dict(scheduler.state_dict())
1602*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(restored_scheduler.mode == scheduler.mode == "triangular2")
1603*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(restored_scheduler._scale_fn_ref) and self.assertIsNotNone(
1604*da0073e9SAndroid Build Coastguard Worker            scheduler._scale_fn_ref
1605*da0073e9SAndroid Build Coastguard Worker        )
1606*da0073e9SAndroid Build Coastguard Worker        self.assertIs(restored_scheduler._scale_fn_custom, None)
1607*da0073e9SAndroid Build Coastguard Worker        self.assertIs(scheduler._scale_fn_custom, None)
1608*da0073e9SAndroid Build Coastguard Worker
1609*da0073e9SAndroid Build Coastguard Worker        # Case 2: Custom `scale_fn`
1610*da0073e9SAndroid Build Coastguard Worker        def scale_fn(_):
1611*da0073e9SAndroid Build Coastguard Worker            return 0.5
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker        scheduler = CyclicLR(
1614*da0073e9SAndroid Build Coastguard Worker            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
1615*da0073e9SAndroid Build Coastguard Worker        )
1616*da0073e9SAndroid Build Coastguard Worker        restored_scheduler = CyclicLR(
1617*da0073e9SAndroid Build Coastguard Worker            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
1618*da0073e9SAndroid Build Coastguard Worker        )
1619*da0073e9SAndroid Build Coastguard Worker        restored_scheduler.load_state_dict(scheduler.state_dict())
1620*da0073e9SAndroid Build Coastguard Worker        self.assertIs(scheduler._scale_fn_custom, scale_fn)
1621*da0073e9SAndroid Build Coastguard Worker        self.assertIs(restored_scheduler._scale_fn_custom, scale_fn)
1622*da0073e9SAndroid Build Coastguard Worker
1623*da0073e9SAndroid Build Coastguard Worker    def test_onecycle_lr_invalid_anneal_strategy(self):
1624*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1625*da0073e9SAndroid Build Coastguard Worker            scheduler = OneCycleLR(
1626*da0073e9SAndroid Build Coastguard Worker                self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS"
1627*da0073e9SAndroid Build Coastguard Worker            )
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker    def test_onecycle_lr_invalid_pct_start(self):
1630*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1631*da0073e9SAndroid Build Coastguard Worker            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker    def test_onecycle_lr_cannot_calculate_total_steps(self):
1634*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1635*da0073e9SAndroid Build Coastguard Worker            scheduler = OneCycleLR(self.opt, max_lr=1e-3)
1636*da0073e9SAndroid Build Coastguard Worker
1637*da0073e9SAndroid Build Coastguard Worker    def test_onecycle_lr_linear_annealing(self):
1638*da0073e9SAndroid Build Coastguard Worker        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
1639*da0073e9SAndroid Build Coastguard Worker        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
1640*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1641*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1642*da0073e9SAndroid Build Coastguard Worker        scheduler = OneCycleLR(
1643*da0073e9SAndroid Build Coastguard Worker            self.opt,
1644*da0073e9SAndroid Build Coastguard Worker            max_lr=25,
1645*da0073e9SAndroid Build Coastguard Worker            final_div_factor=2,
1646*da0073e9SAndroid Build Coastguard Worker            base_momentum=1,
1647*da0073e9SAndroid Build Coastguard Worker            max_momentum=22,
1648*da0073e9SAndroid Build Coastguard Worker            total_steps=10,
1649*da0073e9SAndroid Build Coastguard Worker            anneal_strategy="linear",
1650*da0073e9SAndroid Build Coastguard Worker        )
1651*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker    def test_onecycle_lr_linear_annealing_three_phases(self):
1654*da0073e9SAndroid Build Coastguard Worker        lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25]
1655*da0073e9SAndroid Build Coastguard Worker        momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22]
1656*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1657*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1658*da0073e9SAndroid Build Coastguard Worker        scheduler = OneCycleLR(
1659*da0073e9SAndroid Build Coastguard Worker            self.opt,
1660*da0073e9SAndroid Build Coastguard Worker            max_lr=25,
1661*da0073e9SAndroid Build Coastguard Worker            div_factor=25,
1662*da0073e9SAndroid Build Coastguard Worker            base_momentum=1,
1663*da0073e9SAndroid Build Coastguard Worker            max_momentum=22,
1664*da0073e9SAndroid Build Coastguard Worker            total_steps=10,
1665*da0073e9SAndroid Build Coastguard Worker            anneal_strategy="linear",
1666*da0073e9SAndroid Build Coastguard Worker            pct_start=0.4,
1667*da0073e9SAndroid Build Coastguard Worker            final_div_factor=4,
1668*da0073e9SAndroid Build Coastguard Worker            three_phase=True,
1669*da0073e9SAndroid Build Coastguard Worker        )
1670*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
1671*da0073e9SAndroid Build Coastguard Worker
1672*da0073e9SAndroid Build Coastguard Worker    def test_onecycle_lr_cosine_annealing(self):
1673*da0073e9SAndroid Build Coastguard Worker        def annealing_cos(start, end, pct):
1674*da0073e9SAndroid Build Coastguard Worker            cos_out = math.cos(math.pi * pct) + 1
1675*da0073e9SAndroid Build Coastguard Worker            return end + (start - end) / 2.0 * cos_out
1676*da0073e9SAndroid Build Coastguard Worker
1677*da0073e9SAndroid Build Coastguard Worker        lr_target = [
1678*da0073e9SAndroid Build Coastguard Worker            1,
1679*da0073e9SAndroid Build Coastguard Worker            13,
1680*da0073e9SAndroid Build Coastguard Worker            25,
1681*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 1 / 7.0),
1682*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 2 / 7.0),
1683*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 3 / 7.0),
1684*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 4 / 7.0),
1685*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 5 / 7.0),
1686*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 6 / 7.0),
1687*da0073e9SAndroid Build Coastguard Worker            0.5,
1688*da0073e9SAndroid Build Coastguard Worker        ]
1689*da0073e9SAndroid Build Coastguard Worker        momentum_target = [
1690*da0073e9SAndroid Build Coastguard Worker            22,
1691*da0073e9SAndroid Build Coastguard Worker            11.5,
1692*da0073e9SAndroid Build Coastguard Worker            1,
1693*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 1 / 7.0),
1694*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 2 / 7.0),
1695*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 3 / 7.0),
1696*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 4 / 7.0),
1697*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 5 / 7.0),
1698*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 6 / 7.0),
1699*da0073e9SAndroid Build Coastguard Worker            22,
1700*da0073e9SAndroid Build Coastguard Worker        ]
1701*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1702*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1703*da0073e9SAndroid Build Coastguard Worker        scheduler = OneCycleLR(
1704*da0073e9SAndroid Build Coastguard Worker            self.opt,
1705*da0073e9SAndroid Build Coastguard Worker            max_lr=25,
1706*da0073e9SAndroid Build Coastguard Worker            final_div_factor=2,
1707*da0073e9SAndroid Build Coastguard Worker            base_momentum=1,
1708*da0073e9SAndroid Build Coastguard Worker            max_momentum=22,
1709*da0073e9SAndroid Build Coastguard Worker            total_steps=10,
1710*da0073e9SAndroid Build Coastguard Worker        )
1711*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
1712*da0073e9SAndroid Build Coastguard Worker
1713*da0073e9SAndroid Build Coastguard Worker    def test_onecycle_lr_legacy_state_dict(self):
1714*da0073e9SAndroid Build Coastguard Worker        scheduler = OneCycleLR(
1715*da0073e9SAndroid Build Coastguard Worker            self.opt,
1716*da0073e9SAndroid Build Coastguard Worker            max_lr=25,
1717*da0073e9SAndroid Build Coastguard Worker            final_div_factor=2,
1718*da0073e9SAndroid Build Coastguard Worker            base_momentum=1,
1719*da0073e9SAndroid Build Coastguard Worker            max_momentum=22,
1720*da0073e9SAndroid Build Coastguard Worker            total_steps=10,
1721*da0073e9SAndroid Build Coastguard Worker            anneal_strategy="cos",
1722*da0073e9SAndroid Build Coastguard Worker        )
1723*da0073e9SAndroid Build Coastguard Worker        delattr(scheduler, "_anneal_func_type")
1724*da0073e9SAndroid Build Coastguard Worker        state_dict = scheduler.state_dict()
1725*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn("anneal_func_type", state_dict)
1726*da0073e9SAndroid Build Coastguard Worker        state_dict["anneal_func"] = OneCycleLR._annealing_cos
1727*da0073e9SAndroid Build Coastguard Worker        scheduler.load_state_dict(state_dict)
1728*da0073e9SAndroid Build Coastguard Worker
1729*da0073e9SAndroid Build Coastguard Worker        def annealing_cos(start, end, pct):
1730*da0073e9SAndroid Build Coastguard Worker            cos_out = math.cos(math.pi * pct) + 1
1731*da0073e9SAndroid Build Coastguard Worker            return end + (start - end) / 2.0 * cos_out
1732*da0073e9SAndroid Build Coastguard Worker
1733*da0073e9SAndroid Build Coastguard Worker        lr_target = [
1734*da0073e9SAndroid Build Coastguard Worker            1,
1735*da0073e9SAndroid Build Coastguard Worker            13,
1736*da0073e9SAndroid Build Coastguard Worker            25,
1737*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 1 / 7.0),
1738*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 2 / 7.0),
1739*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 3 / 7.0),
1740*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 4 / 7.0),
1741*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 5 / 7.0),
1742*da0073e9SAndroid Build Coastguard Worker            annealing_cos(25, 0.5, 6 / 7.0),
1743*da0073e9SAndroid Build Coastguard Worker            0.5,
1744*da0073e9SAndroid Build Coastguard Worker        ]
1745*da0073e9SAndroid Build Coastguard Worker        momentum_target = [
1746*da0073e9SAndroid Build Coastguard Worker            22,
1747*da0073e9SAndroid Build Coastguard Worker            11.5,
1748*da0073e9SAndroid Build Coastguard Worker            1,
1749*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 1 / 7.0),
1750*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 2 / 7.0),
1751*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 3 / 7.0),
1752*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 4 / 7.0),
1753*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 5 / 7.0),
1754*da0073e9SAndroid Build Coastguard Worker            annealing_cos(1, 22, 6 / 7.0),
1755*da0073e9SAndroid Build Coastguard Worker            22,
1756*da0073e9SAndroid Build Coastguard Worker        ]
1757*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1758*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1759*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker    def test_cycle_lr_with_adam(self):
1762*da0073e9SAndroid Build Coastguard Worker        old_opt = self.opt
1763*da0073e9SAndroid Build Coastguard Worker        self.opt = Adam(
1764*da0073e9SAndroid Build Coastguard Worker            [
1765*da0073e9SAndroid Build Coastguard Worker                {"params": self.net.conv1.parameters()},
1766*da0073e9SAndroid Build Coastguard Worker                {"params": self.net.conv2.parameters(), "lr": 0.5},
1767*da0073e9SAndroid Build Coastguard Worker            ],
1768*da0073e9SAndroid Build Coastguard Worker            lr=0.05,
1769*da0073e9SAndroid Build Coastguard Worker        )
1770*da0073e9SAndroid Build Coastguard Worker
1771*da0073e9SAndroid Build Coastguard Worker        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
1772*da0073e9SAndroid Build Coastguard Worker        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
1773*da0073e9SAndroid Build Coastguard Worker        lr_targets = [lr_target, lr_target]
1774*da0073e9SAndroid Build Coastguard Worker        momentum_targets = [momentum_target, momentum_target]
1775*da0073e9SAndroid Build Coastguard Worker        scheduler = OneCycleLR(
1776*da0073e9SAndroid Build Coastguard Worker            self.opt,
1777*da0073e9SAndroid Build Coastguard Worker            max_lr=25,
1778*da0073e9SAndroid Build Coastguard Worker            final_div_factor=2,
1779*da0073e9SAndroid Build Coastguard Worker            base_momentum=1,
1780*da0073e9SAndroid Build Coastguard Worker            max_momentum=22,
1781*da0073e9SAndroid Build Coastguard Worker            total_steps=10,
1782*da0073e9SAndroid Build Coastguard Worker            anneal_strategy="linear",
1783*da0073e9SAndroid Build Coastguard Worker        )
1784*da0073e9SAndroid Build Coastguard Worker        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True)
1785*da0073e9SAndroid Build Coastguard Worker        self.opt = old_opt  # set optimizer back to SGD
1786*da0073e9SAndroid Build Coastguard Worker
1787*da0073e9SAndroid Build Coastguard Worker    def test_lambda_lr(self):
1788*da0073e9SAndroid Build Coastguard Worker        epochs = 10
1789*da0073e9SAndroid Build Coastguard Worker        self.opt.param_groups[0]["lr"] = 0.05
1790*da0073e9SAndroid Build Coastguard Worker        self.opt.param_groups[1]["lr"] = 0.4
1791*da0073e9SAndroid Build Coastguard Worker        targets = [
1792*da0073e9SAndroid Build Coastguard Worker            [0.05 * (0.9**x) for x in range(epochs)],
1793*da0073e9SAndroid Build Coastguard Worker            [0.4 * (0.8**x) for x in range(epochs)],
1794*da0073e9SAndroid Build Coastguard Worker        ]
1795*da0073e9SAndroid Build Coastguard Worker        scheduler = LambdaLR(
1796*da0073e9SAndroid Build Coastguard Worker            self.opt, lr_lambda=[lambda x1: 0.9**x1, lambda x2: 0.8**x2]
1797*da0073e9SAndroid Build Coastguard Worker        )
1798*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
1799*da0073e9SAndroid Build Coastguard Worker
1800*da0073e9SAndroid Build Coastguard Worker    def test_multiplicative_lr(self):
1801*da0073e9SAndroid Build Coastguard Worker        epochs = 10
1802*da0073e9SAndroid Build Coastguard Worker        self.opt.param_groups[0]["lr"] = 0.05
1803*da0073e9SAndroid Build Coastguard Worker        self.opt.param_groups[1]["lr"] = 0.4
1804*da0073e9SAndroid Build Coastguard Worker        targets = [
1805*da0073e9SAndroid Build Coastguard Worker            [0.05 * (0.9**x) for x in range(epochs)],
1806*da0073e9SAndroid Build Coastguard Worker            [0.4 * (0.8**x) for x in range(epochs)],
1807*da0073e9SAndroid Build Coastguard Worker        ]
1808*da0073e9SAndroid Build Coastguard Worker        scheduler = MultiplicativeLR(
1809*da0073e9SAndroid Build Coastguard Worker            self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8]
1810*da0073e9SAndroid Build Coastguard Worker        )
1811*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, epochs)
1812*da0073e9SAndroid Build Coastguard Worker
1813*da0073e9SAndroid Build Coastguard Worker    @parametrize("T_mult", [1, 2, 4])
1814*da0073e9SAndroid Build Coastguard Worker    def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
1815*da0073e9SAndroid Build Coastguard Worker        iters = 100
1816*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
1817*da0073e9SAndroid Build Coastguard Worker        T_i = 10
1818*da0073e9SAndroid Build Coastguard Worker        T_cur = 0
1819*da0073e9SAndroid Build Coastguard Worker        targets = [[0.05], [0.5]]
1820*da0073e9SAndroid Build Coastguard Worker        scheduler = CosineAnnealingWarmRestarts(
1821*da0073e9SAndroid Build Coastguard Worker            self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min
1822*da0073e9SAndroid Build Coastguard Worker        )
1823*da0073e9SAndroid Build Coastguard Worker        for _ in range(1, iters, 1):
1824*da0073e9SAndroid Build Coastguard Worker            T_cur += 1
1825*da0073e9SAndroid Build Coastguard Worker            if T_cur >= T_i:
1826*da0073e9SAndroid Build Coastguard Worker                T_cur = T_cur - T_i
1827*da0073e9SAndroid Build Coastguard Worker                T_i = int(T_mult) * T_i
1828*da0073e9SAndroid Build Coastguard Worker            targets[0] += [
1829*da0073e9SAndroid Build Coastguard Worker                eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1830*da0073e9SAndroid Build Coastguard Worker            ]
1831*da0073e9SAndroid Build Coastguard Worker            targets[1] += [
1832*da0073e9SAndroid Build Coastguard Worker                eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1833*da0073e9SAndroid Build Coastguard Worker            ]
1834*da0073e9SAndroid Build Coastguard Worker        self._test(scheduler, targets, iters)
1835*da0073e9SAndroid Build Coastguard Worker
1836*da0073e9SAndroid Build Coastguard Worker    def test_CosineAnnealingWarmRestarts_lr2(self):
1837*da0073e9SAndroid Build Coastguard Worker        iters = 30
1838*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
1839*da0073e9SAndroid Build Coastguard Worker        T_mults = [1, 2, 4]
1840*da0073e9SAndroid Build Coastguard Worker        for T_mult in T_mults:
1841*da0073e9SAndroid Build Coastguard Worker            T_i = 10
1842*da0073e9SAndroid Build Coastguard Worker            T_cur = 0
1843*da0073e9SAndroid Build Coastguard Worker            targets = [[0.05], [0.5]]
1844*da0073e9SAndroid Build Coastguard Worker            scheduler = CosineAnnealingWarmRestarts(
1845*da0073e9SAndroid Build Coastguard Worker                self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min
1846*da0073e9SAndroid Build Coastguard Worker            )
1847*da0073e9SAndroid Build Coastguard Worker            for _ in torch.arange(0.1, iters, 0.1):
1848*da0073e9SAndroid Build Coastguard Worker                T_cur = round(T_cur + 0.1, 1)
1849*da0073e9SAndroid Build Coastguard Worker                if T_cur >= T_i:
1850*da0073e9SAndroid Build Coastguard Worker                    T_cur = T_cur - T_i
1851*da0073e9SAndroid Build Coastguard Worker                    T_i = int(T_mult) * T_i
1852*da0073e9SAndroid Build Coastguard Worker                targets[0] += [
1853*da0073e9SAndroid Build Coastguard Worker                    eta_min
1854*da0073e9SAndroid Build Coastguard Worker                    + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1855*da0073e9SAndroid Build Coastguard Worker                ]
1856*da0073e9SAndroid Build Coastguard Worker                targets[1] += [
1857*da0073e9SAndroid Build Coastguard Worker                    eta_min
1858*da0073e9SAndroid Build Coastguard Worker                    + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1859*da0073e9SAndroid Build Coastguard Worker                ]
1860*da0073e9SAndroid Build Coastguard Worker            self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters)
1861*da0073e9SAndroid Build Coastguard Worker
1862*da0073e9SAndroid Build Coastguard Worker    def test_CosineAnnealingWarmRestarts_lr3(self):
1863*da0073e9SAndroid Build Coastguard Worker        epochs_for_T_mults = [
1864*da0073e9SAndroid Build Coastguard Worker            [0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13],
1865*da0073e9SAndroid Build Coastguard Worker            [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3],
1866*da0073e9SAndroid Build Coastguard Worker            [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50],
1867*da0073e9SAndroid Build Coastguard Worker        ]
1868*da0073e9SAndroid Build Coastguard Worker        T_curs_for_T_mults = [
1869*da0073e9SAndroid Build Coastguard Worker            [1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3],
1870*da0073e9SAndroid Build Coastguard Worker            [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3],
1871*da0073e9SAndroid Build Coastguard Worker            [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10],
1872*da0073e9SAndroid Build Coastguard Worker        ]
1873*da0073e9SAndroid Build Coastguard Worker        T_is_for_T_mults = [
1874*da0073e9SAndroid Build Coastguard Worker            [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
1875*da0073e9SAndroid Build Coastguard Worker            [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10],
1876*da0073e9SAndroid Build Coastguard Worker            [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90],
1877*da0073e9SAndroid Build Coastguard Worker        ]
1878*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
1879*da0073e9SAndroid Build Coastguard Worker        T_mults = [1, 2, 3]
1880*da0073e9SAndroid Build Coastguard Worker        for epochs, T_mult, T_curs, T_is in zip(
1881*da0073e9SAndroid Build Coastguard Worker            epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults
1882*da0073e9SAndroid Build Coastguard Worker        ):
1883*da0073e9SAndroid Build Coastguard Worker            targets = [[0.05], [0.5]]
1884*da0073e9SAndroid Build Coastguard Worker            scheduler = CosineAnnealingWarmRestarts(
1885*da0073e9SAndroid Build Coastguard Worker                self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min
1886*da0073e9SAndroid Build Coastguard Worker            )
1887*da0073e9SAndroid Build Coastguard Worker            for T_cur, T_i in zip(T_curs, T_is):
1888*da0073e9SAndroid Build Coastguard Worker                targets[0] += [
1889*da0073e9SAndroid Build Coastguard Worker                    eta_min
1890*da0073e9SAndroid Build Coastguard Worker                    + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1891*da0073e9SAndroid Build Coastguard Worker                ]
1892*da0073e9SAndroid Build Coastguard Worker                targets[1] += [
1893*da0073e9SAndroid Build Coastguard Worker                    eta_min
1894*da0073e9SAndroid Build Coastguard Worker                    + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1895*da0073e9SAndroid Build Coastguard Worker                ]
1896*da0073e9SAndroid Build Coastguard Worker            self._test_interleaved_CosineAnnealingWarmRestarts(
1897*da0073e9SAndroid Build Coastguard Worker                scheduler, targets, epochs
1898*da0073e9SAndroid Build Coastguard Worker            )
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker    def test_swalr_no_anneal(self):
1901*da0073e9SAndroid Build Coastguard Worker        epochs, swa_start, swa_lr = 10, 5, 0.01
1902*da0073e9SAndroid Build Coastguard Worker        initial_lrs = [group["lr"] for group in self.opt.param_groups]
1903*da0073e9SAndroid Build Coastguard Worker        targets = [
1904*da0073e9SAndroid Build Coastguard Worker            [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1)
1905*da0073e9SAndroid Build Coastguard Worker            for lr in initial_lrs
1906*da0073e9SAndroid Build Coastguard Worker        ]
1907*da0073e9SAndroid Build Coastguard Worker        swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr)
1908*da0073e9SAndroid Build Coastguard Worker        self._test_swalr(swa_scheduler, None, targets, swa_start, epochs)
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker    def test_swalr_cosine_anneal_after_multiplicative(self):
1911*da0073e9SAndroid Build Coastguard Worker        # same swa_lr for different param_groups
1912*da0073e9SAndroid Build Coastguard Worker        epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5
1913*da0073e9SAndroid Build Coastguard Worker        mult_factor = 0.9
1914*da0073e9SAndroid Build Coastguard Worker        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
1915*da0073e9SAndroid Build Coastguard Worker        swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr)
1916*da0073e9SAndroid Build Coastguard Worker
1917*da0073e9SAndroid Build Coastguard Worker        def anneal_coef(t):
1918*da0073e9SAndroid Build Coastguard Worker            if t + 1 >= anneal_epochs:
1919*da0073e9SAndroid Build Coastguard Worker                return 0.0
1920*da0073e9SAndroid Build Coastguard Worker            return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2
1921*da0073e9SAndroid Build Coastguard Worker
1922*da0073e9SAndroid Build Coastguard Worker        initial_lrs = [group["lr"] for group in self.opt.param_groups]
1923*da0073e9SAndroid Build Coastguard Worker        targets_before_swa = [
1924*da0073e9SAndroid Build Coastguard Worker            [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs
1925*da0073e9SAndroid Build Coastguard Worker        ]
1926*da0073e9SAndroid Build Coastguard Worker        swa_epochs = epochs - swa_start - 1
1927*da0073e9SAndroid Build Coastguard Worker        targets = [
1928*da0073e9SAndroid Build Coastguard Worker            lrs
1929*da0073e9SAndroid Build Coastguard Worker            + [
1930*da0073e9SAndroid Build Coastguard Worker                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))
1931*da0073e9SAndroid Build Coastguard Worker                for t in range(swa_epochs)
1932*da0073e9SAndroid Build Coastguard Worker            ]
1933*da0073e9SAndroid Build Coastguard Worker            for lrs in targets_before_swa
1934*da0073e9SAndroid Build Coastguard Worker        ]
1935*da0073e9SAndroid Build Coastguard Worker
1936*da0073e9SAndroid Build Coastguard Worker        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)
1937*da0073e9SAndroid Build Coastguard Worker
1938*da0073e9SAndroid Build Coastguard Worker    def test_swalr_linear_anneal_after_multiplicative(self):
1939*da0073e9SAndroid Build Coastguard Worker        # separate swa_lr for different param_groups
1940*da0073e9SAndroid Build Coastguard Worker        epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4
1941*da0073e9SAndroid Build Coastguard Worker        mult_factor = 0.9
1942*da0073e9SAndroid Build Coastguard Worker        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
1943*da0073e9SAndroid Build Coastguard Worker        swa_scheduler = SWALR(
1944*da0073e9SAndroid Build Coastguard Worker            self.opt,
1945*da0073e9SAndroid Build Coastguard Worker            anneal_epochs=anneal_epochs,
1946*da0073e9SAndroid Build Coastguard Worker            anneal_strategy="linear",
1947*da0073e9SAndroid Build Coastguard Worker            swa_lr=swa_lrs,
1948*da0073e9SAndroid Build Coastguard Worker        )
1949*da0073e9SAndroid Build Coastguard Worker
1950*da0073e9SAndroid Build Coastguard Worker        def anneal_coef(t):
1951*da0073e9SAndroid Build Coastguard Worker            if t + 1 >= anneal_epochs:
1952*da0073e9SAndroid Build Coastguard Worker                return 0.0
1953*da0073e9SAndroid Build Coastguard Worker            return 1 - (t + 1) / anneal_epochs
1954*da0073e9SAndroid Build Coastguard Worker
1955*da0073e9SAndroid Build Coastguard Worker        initial_lrs = [group["lr"] for group in self.opt.param_groups]
1956*da0073e9SAndroid Build Coastguard Worker        targets_before_swa = [
1957*da0073e9SAndroid Build Coastguard Worker            [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs
1958*da0073e9SAndroid Build Coastguard Worker        ]
1959*da0073e9SAndroid Build Coastguard Worker        swa_epochs = epochs - swa_start - 1
1960*da0073e9SAndroid Build Coastguard Worker        targets = [
1961*da0073e9SAndroid Build Coastguard Worker            lrs
1962*da0073e9SAndroid Build Coastguard Worker            + [
1963*da0073e9SAndroid Build Coastguard Worker                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))
1964*da0073e9SAndroid Build Coastguard Worker                for t in range(swa_epochs)
1965*da0073e9SAndroid Build Coastguard Worker            ]
1966*da0073e9SAndroid Build Coastguard Worker            for lrs, swa_lr in zip(targets_before_swa, swa_lrs)
1967*da0073e9SAndroid Build Coastguard Worker        ]
1968*da0073e9SAndroid Build Coastguard Worker
1969*da0073e9SAndroid Build Coastguard Worker        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)
1970*da0073e9SAndroid Build Coastguard Worker
1971*da0073e9SAndroid Build Coastguard Worker    def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs):
1972*da0073e9SAndroid Build Coastguard Worker        for epoch in range(epochs):
1973*da0073e9SAndroid Build Coastguard Worker            for param_group, target in zip(self.opt.param_groups, targets):
1974*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1975*da0073e9SAndroid Build Coastguard Worker                    target[epoch],
1976*da0073e9SAndroid Build Coastguard Worker                    param_group["lr"],
1977*da0073e9SAndroid Build Coastguard Worker                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
1978*da0073e9SAndroid Build Coastguard Worker                        epoch, target[epoch], param_group["lr"]
1979*da0073e9SAndroid Build Coastguard Worker                    ),
1980*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
1981*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
1982*da0073e9SAndroid Build Coastguard Worker                )
1983*da0073e9SAndroid Build Coastguard Worker            if epoch >= swa_start:
1984*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
1985*da0073e9SAndroid Build Coastguard Worker                swa_scheduler.step()
1986*da0073e9SAndroid Build Coastguard Worker            elif scheduler is not None:
1987*da0073e9SAndroid Build Coastguard Worker                self.opt.step()
1988*da0073e9SAndroid Build Coastguard Worker                scheduler.step()
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker    def test_swalr_hypers(self):
1991*da0073e9SAndroid Build Coastguard Worker        # Test that SWALR raises errors for incorrect hyper-parameters
1992*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "anneal_strategy must"):
1993*da0073e9SAndroid Build Coastguard Worker            swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0)
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
1996*da0073e9SAndroid Build Coastguard Worker            swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0)
1997*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
1998*da0073e9SAndroid Build Coastguard Worker            swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0)
1999*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "swa_lr must"):
2000*da0073e9SAndroid Build Coastguard Worker            swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01])
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard Worker    def test_step_lr_state_dict(self):
2003*da0073e9SAndroid Build Coastguard Worker        self._check_scheduler_state_dict(
2004*da0073e9SAndroid Build Coastguard Worker            lambda: StepLR(self.opt, gamma=0.1, step_size=3),
2005*da0073e9SAndroid Build Coastguard Worker            lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1),
2006*da0073e9SAndroid Build Coastguard Worker        )
2007*da0073e9SAndroid Build Coastguard Worker
2008*da0073e9SAndroid Build Coastguard Worker    def test_multi_step_lr_state_dict(self):
2009*da0073e9SAndroid Build Coastguard Worker        self._check_scheduler_state_dict(
2010*da0073e9SAndroid Build Coastguard Worker            lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
2011*da0073e9SAndroid Build Coastguard Worker            lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]),
2012*da0073e9SAndroid Build Coastguard Worker        )
2013*da0073e9SAndroid Build Coastguard Worker
2014*da0073e9SAndroid Build Coastguard Worker    def test_exp_step_lr_state_dict(self):
2015*da0073e9SAndroid Build Coastguard Worker        self._check_scheduler_state_dict(
2016*da0073e9SAndroid Build Coastguard Worker            lambda: ExponentialLR(self.opt, gamma=0.1),
2017*da0073e9SAndroid Build Coastguard Worker            lambda: ExponentialLR(self.opt, gamma=0.01),
2018*da0073e9SAndroid Build Coastguard Worker        )
2019*da0073e9SAndroid Build Coastguard Worker
2020*da0073e9SAndroid Build Coastguard Worker    def test_cosine_lr_state_dict(self):
2021*da0073e9SAndroid Build Coastguard Worker        epochs = 10
2022*da0073e9SAndroid Build Coastguard Worker        eta_min = 1e-10
2023*da0073e9SAndroid Build Coastguard Worker        self._check_scheduler_state_dict(
2024*da0073e9SAndroid Build Coastguard Worker            lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
2025*da0073e9SAndroid Build Coastguard Worker            lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
2026*da0073e9SAndroid Build Coastguard Worker            epochs=epochs,
2027*da0073e9SAndroid Build Coastguard Worker        )
2028*da0073e9SAndroid Build Coastguard Worker
2029*da0073e9SAndroid Build Coastguard Worker    def test_reduce_lr_on_plateau_state_dict(self):
2030*da0073e9SAndroid Build Coastguard Worker        scheduler = ReduceLROnPlateau(self.opt, mode="min", factor=0.1, patience=2)
2031*da0073e9SAndroid Build Coastguard Worker        for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]:
2032*da0073e9SAndroid Build Coastguard Worker            scheduler.step(score)
2033*da0073e9SAndroid Build Coastguard Worker        scheduler_copy = ReduceLROnPlateau(
2034*da0073e9SAndroid Build Coastguard Worker            self.opt, mode="max", factor=0.5, patience=10
2035*da0073e9SAndroid Build Coastguard Worker        )
2036*da0073e9SAndroid Build Coastguard Worker        scheduler_copy.load_state_dict(scheduler.state_dict())
2037*da0073e9SAndroid Build Coastguard Worker        for key in scheduler.__dict__.keys():
2038*da0073e9SAndroid Build Coastguard Worker            if key not in {"optimizer", "is_better"}:
2039*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
2040*da0073e9SAndroid Build Coastguard Worker
2041*da0073e9SAndroid Build Coastguard Worker    def test_lambda_lr_state_dict_fn(self):
2042*da0073e9SAndroid Build Coastguard Worker        scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x)
2043*da0073e9SAndroid Build Coastguard Worker        state = scheduler.state_dict()
2044*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(state["lr_lambdas"][0])
2045*da0073e9SAndroid Build Coastguard Worker
2046*da0073e9SAndroid Build Coastguard Worker        scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x)
2047*da0073e9SAndroid Build Coastguard Worker        scheduler_copy.load_state_dict(state)
2048*da0073e9SAndroid Build Coastguard Worker        for key in scheduler.__dict__.keys():
2049*da0073e9SAndroid Build Coastguard Worker            if key not in {"optimizer", "lr_lambdas"}:
2050*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
2051*da0073e9SAndroid Build Coastguard Worker
2052*da0073e9SAndroid Build Coastguard Worker    def test_lambda_lr_state_dict_obj(self):
2053*da0073e9SAndroid Build Coastguard Worker        scheduler = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(10))
2054*da0073e9SAndroid Build Coastguard Worker        state = scheduler.state_dict()
2055*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(state["lr_lambdas"][0])
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Worker        scheduler_copy = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(-1))
2058*da0073e9SAndroid Build Coastguard Worker        scheduler_copy.load_state_dict(state)
2059*da0073e9SAndroid Build Coastguard Worker        for key in scheduler.__dict__.keys():
2060*da0073e9SAndroid Build Coastguard Worker            if key not in {"optimizer"}:
2061*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
2062*da0073e9SAndroid Build Coastguard Worker
2063*da0073e9SAndroid Build Coastguard Worker    def test_CosineAnnealingWarmRestarts_lr_state_dict(self):
2064*da0073e9SAndroid Build Coastguard Worker        self._check_scheduler_state_dict(
2065*da0073e9SAndroid Build Coastguard Worker            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2),
2066*da0073e9SAndroid Build Coastguard Worker            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100),
2067*da0073e9SAndroid Build Coastguard Worker        )
2068*da0073e9SAndroid Build Coastguard Worker
2069*da0073e9SAndroid Build Coastguard Worker    def test_swa_lr_state_dict(self):
2070*da0073e9SAndroid Build Coastguard Worker        self._check_scheduler_state_dict(
2071*da0073e9SAndroid Build Coastguard Worker            lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5),
2072*da0073e9SAndroid Build Coastguard Worker            lambda: SWALR(
2073*da0073e9SAndroid Build Coastguard Worker                self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0
2074*da0073e9SAndroid Build Coastguard Worker            ),
2075*da0073e9SAndroid Build Coastguard Worker        )
2076*da0073e9SAndroid Build Coastguard Worker
2077*da0073e9SAndroid Build Coastguard Worker    def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
2078*da0073e9SAndroid Build Coastguard Worker        scheduler = constr()
2079*da0073e9SAndroid Build Coastguard Worker        for _ in range(epochs):
2080*da0073e9SAndroid Build Coastguard Worker            scheduler.optimizer.step()
2081*da0073e9SAndroid Build Coastguard Worker            scheduler.step()
2082*da0073e9SAndroid Build Coastguard Worker        scheduler_copy = constr2()
2083*da0073e9SAndroid Build Coastguard Worker        scheduler_copy.load_state_dict(scheduler.state_dict())
2084*da0073e9SAndroid Build Coastguard Worker        for key in scheduler.__dict__.keys():
2085*da0073e9SAndroid Build Coastguard Worker            if key != "optimizer":
2086*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
2087*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr())
2088*da0073e9SAndroid Build Coastguard Worker
2089*da0073e9SAndroid Build Coastguard Worker    def _test_get_last_lr(self, schedulers, targets, epochs=10):
2090*da0073e9SAndroid Build Coastguard Worker        if isinstance(schedulers, LRScheduler):
2091*da0073e9SAndroid Build Coastguard Worker            schedulers = [schedulers]
2092*da0073e9SAndroid Build Coastguard Worker        optimizers = {scheduler.optimizer for scheduler in schedulers}
2093*da0073e9SAndroid Build Coastguard Worker        for epoch in range(epochs):
2094*da0073e9SAndroid Build Coastguard Worker            result = [scheduler.get_last_lr() for scheduler in schedulers]
2095*da0073e9SAndroid Build Coastguard Worker            [optimizer.step() for optimizer in optimizers]
2096*da0073e9SAndroid Build Coastguard Worker            [scheduler.step() for scheduler in schedulers]
2097*da0073e9SAndroid Build Coastguard Worker            target = [[t[epoch] for t in targets]] * len(schedulers)
2098*da0073e9SAndroid Build Coastguard Worker            for t, r in zip(target, result):
2099*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2100*da0073e9SAndroid Build Coastguard Worker                    t,
2101*da0073e9SAndroid Build Coastguard Worker                    r,
2102*da0073e9SAndroid Build Coastguard Worker                    msg=f"LR is wrong in epoch {epoch}: expected {t}, got {r}",
2103*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
2104*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2105*da0073e9SAndroid Build Coastguard Worker                )
2106*da0073e9SAndroid Build Coastguard Worker
2107*da0073e9SAndroid Build Coastguard Worker    def _test_with_epoch(self, schedulers, targets, epochs=10):
2108*da0073e9SAndroid Build Coastguard Worker        if isinstance(schedulers, LRScheduler):
2109*da0073e9SAndroid Build Coastguard Worker            schedulers = [schedulers]
2110*da0073e9SAndroid Build Coastguard Worker        optimizers = {scheduler.optimizer for scheduler in schedulers}
2111*da0073e9SAndroid Build Coastguard Worker        for epoch in range(epochs):
2112*da0073e9SAndroid Build Coastguard Worker            [optimizer.step() for optimizer in optimizers]
2113*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
2114*da0073e9SAndroid Build Coastguard Worker                [
2115*da0073e9SAndroid Build Coastguard Worker                    scheduler.step(epoch) for scheduler in schedulers
2116*da0073e9SAndroid Build Coastguard Worker                ]  # step before assert: skip initial lr
2117*da0073e9SAndroid Build Coastguard Worker                self._check_warning_is_epoch_deprecation_warning(
2118*da0073e9SAndroid Build Coastguard Worker                    w, num_warnings=len(schedulers)
2119*da0073e9SAndroid Build Coastguard Worker                )
2120*da0073e9SAndroid Build Coastguard Worker            for param_group, target in zip(self.opt.param_groups, targets):
2121*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2122*da0073e9SAndroid Build Coastguard Worker                    target[epoch],
2123*da0073e9SAndroid Build Coastguard Worker                    param_group["lr"],
2124*da0073e9SAndroid Build Coastguard Worker                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2125*da0073e9SAndroid Build Coastguard Worker                        epoch, target[epoch], param_group["lr"]
2126*da0073e9SAndroid Build Coastguard Worker                    ),
2127*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
2128*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2129*da0073e9SAndroid Build Coastguard Worker                )
2130*da0073e9SAndroid Build Coastguard Worker
2131*da0073e9SAndroid Build Coastguard Worker    def _test(self, schedulers, targets, epochs=10):
2132*da0073e9SAndroid Build Coastguard Worker        if isinstance(schedulers, LRScheduler):
2133*da0073e9SAndroid Build Coastguard Worker            schedulers = [schedulers]
2134*da0073e9SAndroid Build Coastguard Worker        for epoch in range(epochs):
2135*da0073e9SAndroid Build Coastguard Worker            for param_group, target in zip(self.opt.param_groups, targets):
2136*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2137*da0073e9SAndroid Build Coastguard Worker                    target[epoch],
2138*da0073e9SAndroid Build Coastguard Worker                    param_group["lr"],
2139*da0073e9SAndroid Build Coastguard Worker                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2140*da0073e9SAndroid Build Coastguard Worker                        epoch, target[epoch], param_group["lr"]
2141*da0073e9SAndroid Build Coastguard Worker                    ),
2142*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
2143*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2144*da0073e9SAndroid Build Coastguard Worker                )
2145*da0073e9SAndroid Build Coastguard Worker            [scheduler.step() for scheduler in schedulers]
2146*da0073e9SAndroid Build Coastguard Worker
2147*da0073e9SAndroid Build Coastguard Worker    def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10):
2148*da0073e9SAndroid Build Coastguard Worker        for index, epoch in enumerate(torch.arange(0, epochs, 0.1)):
2149*da0073e9SAndroid Build Coastguard Worker            epoch = round(epoch.item(), 1)
2150*da0073e9SAndroid Build Coastguard Worker            scheduler.step(epoch)
2151*da0073e9SAndroid Build Coastguard Worker            for param_group, target in zip(self.opt.param_groups, targets):
2152*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2153*da0073e9SAndroid Build Coastguard Worker                    target[index],
2154*da0073e9SAndroid Build Coastguard Worker                    param_group["lr"],
2155*da0073e9SAndroid Build Coastguard Worker                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2156*da0073e9SAndroid Build Coastguard Worker                        epoch, target[index], param_group["lr"]
2157*da0073e9SAndroid Build Coastguard Worker                    ),
2158*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
2159*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2160*da0073e9SAndroid Build Coastguard Worker                )
2161*da0073e9SAndroid Build Coastguard Worker
2162*da0073e9SAndroid Build Coastguard Worker    def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs):
2163*da0073e9SAndroid Build Coastguard Worker        for index, epoch in enumerate(epochs):
2164*da0073e9SAndroid Build Coastguard Worker            scheduler.step(epoch)
2165*da0073e9SAndroid Build Coastguard Worker            for param_group, target in zip(self.opt.param_groups, targets):
2166*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2167*da0073e9SAndroid Build Coastguard Worker                    target[index],
2168*da0073e9SAndroid Build Coastguard Worker                    param_group["lr"],
2169*da0073e9SAndroid Build Coastguard Worker                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2170*da0073e9SAndroid Build Coastguard Worker                        epoch, target[index], param_group["lr"]
2171*da0073e9SAndroid Build Coastguard Worker                    ),
2172*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
2173*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2174*da0073e9SAndroid Build Coastguard Worker                )
2175*da0073e9SAndroid Build Coastguard Worker
2176*da0073e9SAndroid Build Coastguard Worker    def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10):
2177*da0073e9SAndroid Build Coastguard Worker        self.setUp()
2178*da0073e9SAndroid Build Coastguard Worker        targets = []
2179*da0073e9SAndroid Build Coastguard Worker        for epoch in range(epochs):
2180*da0073e9SAndroid Build Coastguard Worker            closed_form_scheduler.optimizer.step()
2181*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
2182*da0073e9SAndroid Build Coastguard Worker                closed_form_scheduler.step(epoch)
2183*da0073e9SAndroid Build Coastguard Worker                self._check_warning_is_epoch_deprecation_warning(w)
2184*da0073e9SAndroid Build Coastguard Worker            targets.append([group["lr"] for group in self.opt.param_groups])
2185*da0073e9SAndroid Build Coastguard Worker        self.setUp()
2186*da0073e9SAndroid Build Coastguard Worker        for epoch in range(epochs):
2187*da0073e9SAndroid Build Coastguard Worker            self.opt.step()
2188*da0073e9SAndroid Build Coastguard Worker            scheduler.step()
2189*da0073e9SAndroid Build Coastguard Worker            for i, param_group in enumerate(self.opt.param_groups):
2190*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2191*da0073e9SAndroid Build Coastguard Worker                    targets[epoch][i],
2192*da0073e9SAndroid Build Coastguard Worker                    param_group["lr"],
2193*da0073e9SAndroid Build Coastguard Worker                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2194*da0073e9SAndroid Build Coastguard Worker                        epoch, targets[epoch][i], param_group["lr"]
2195*da0073e9SAndroid Build Coastguard Worker                    ),
2196*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
2197*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2198*da0073e9SAndroid Build Coastguard Worker                )
2199*da0073e9SAndroid Build Coastguard Worker
2200*da0073e9SAndroid Build Coastguard Worker    def _test_reduce_lr_on_plateau(
2201*da0073e9SAndroid Build Coastguard Worker        self, schedulers, targets, metrics, epochs=10, verbose=False
2202*da0073e9SAndroid Build Coastguard Worker    ):
2203*da0073e9SAndroid Build Coastguard Worker        if isinstance(schedulers, (LRScheduler, ReduceLROnPlateau)):
2204*da0073e9SAndroid Build Coastguard Worker            schedulers = [schedulers]
2205*da0073e9SAndroid Build Coastguard Worker        for epoch in range(epochs):
2206*da0073e9SAndroid Build Coastguard Worker            self.opt.step()
2207*da0073e9SAndroid Build Coastguard Worker            for scheduler in schedulers:
2208*da0073e9SAndroid Build Coastguard Worker                if isinstance(scheduler, ReduceLROnPlateau):
2209*da0073e9SAndroid Build Coastguard Worker                    scheduler.step(metrics[epoch])
2210*da0073e9SAndroid Build Coastguard Worker                else:
2211*da0073e9SAndroid Build Coastguard Worker                    scheduler.step()
2212*da0073e9SAndroid Build Coastguard Worker            if verbose:
2213*da0073e9SAndroid Build Coastguard Worker                print("epoch{}:\tlr={}".format(epoch, self.opt.param_groups[0]["lr"]))
2214*da0073e9SAndroid Build Coastguard Worker            for param_group, target in zip(self.opt.param_groups, targets):
2215*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2216*da0073e9SAndroid Build Coastguard Worker                    target[epoch],
2217*da0073e9SAndroid Build Coastguard Worker                    param_group["lr"],
2218*da0073e9SAndroid Build Coastguard Worker                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2219*da0073e9SAndroid Build Coastguard Worker                        epoch, target[epoch], param_group["lr"]
2220*da0073e9SAndroid Build Coastguard Worker                    ),
2221*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
2222*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2223*da0073e9SAndroid Build Coastguard Worker                )
2224*da0073e9SAndroid Build Coastguard Worker
2225*da0073e9SAndroid Build Coastguard Worker    def _test_cycle_lr(
2226*da0073e9SAndroid Build Coastguard Worker        self,
2227*da0073e9SAndroid Build Coastguard Worker        scheduler,
2228*da0073e9SAndroid Build Coastguard Worker        lr_targets,
2229*da0073e9SAndroid Build Coastguard Worker        momentum_targets,
2230*da0073e9SAndroid Build Coastguard Worker        batch_iterations,
2231*da0073e9SAndroid Build Coastguard Worker        verbose=False,
2232*da0073e9SAndroid Build Coastguard Worker        use_beta1=False,
2233*da0073e9SAndroid Build Coastguard Worker    ):
2234*da0073e9SAndroid Build Coastguard Worker        for batch_num in range(batch_iterations):
2235*da0073e9SAndroid Build Coastguard Worker            if verbose:
2236*da0073e9SAndroid Build Coastguard Worker                if "momentum" in self.opt.param_groups[0].keys():
2237*da0073e9SAndroid Build Coastguard Worker                    print(
2238*da0073e9SAndroid Build Coastguard Worker                        "batch{}:\tlr={},momentum={}".format(
2239*da0073e9SAndroid Build Coastguard Worker                            batch_num,
2240*da0073e9SAndroid Build Coastguard Worker                            self.opt.param_groups[0]["lr"],
2241*da0073e9SAndroid Build Coastguard Worker                            self.opt.param_groups[0]["momentum"],
2242*da0073e9SAndroid Build Coastguard Worker                        )
2243*da0073e9SAndroid Build Coastguard Worker                    )
2244*da0073e9SAndroid Build Coastguard Worker                elif use_beta1 and "betas" in self.opt.param_groups[0].keys():
2245*da0073e9SAndroid Build Coastguard Worker                    print(
2246*da0073e9SAndroid Build Coastguard Worker                        "batch{}:\tlr={},beta1={}".format(
2247*da0073e9SAndroid Build Coastguard Worker                            batch_num,
2248*da0073e9SAndroid Build Coastguard Worker                            self.opt.param_groups[0]["lr"],
2249*da0073e9SAndroid Build Coastguard Worker                            self.opt.param_groups[0]["betas"][0],
2250*da0073e9SAndroid Build Coastguard Worker                        )
2251*da0073e9SAndroid Build Coastguard Worker                    )
2252*da0073e9SAndroid Build Coastguard Worker                else:
2253*da0073e9SAndroid Build Coastguard Worker                    print(
2254*da0073e9SAndroid Build Coastguard Worker                        "batch{}:\tlr={}".format(
2255*da0073e9SAndroid Build Coastguard Worker                            batch_num, self.opt.param_groups[0]["lr"]
2256*da0073e9SAndroid Build Coastguard Worker                        )
2257*da0073e9SAndroid Build Coastguard Worker                    )
2258*da0073e9SAndroid Build Coastguard Worker
2259*da0073e9SAndroid Build Coastguard Worker            for param_group, lr_target, momentum_target in zip(
2260*da0073e9SAndroid Build Coastguard Worker                self.opt.param_groups, lr_targets, momentum_targets
2261*da0073e9SAndroid Build Coastguard Worker            ):
2262*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2263*da0073e9SAndroid Build Coastguard Worker                    lr_target[batch_num],
2264*da0073e9SAndroid Build Coastguard Worker                    param_group["lr"],
2265*da0073e9SAndroid Build Coastguard Worker                    msg="LR is wrong in batch_num {}: expected {}, got {}".format(
2266*da0073e9SAndroid Build Coastguard Worker                        batch_num, lr_target[batch_num], param_group["lr"]
2267*da0073e9SAndroid Build Coastguard Worker                    ),
2268*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
2269*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2270*da0073e9SAndroid Build Coastguard Worker                )
2271*da0073e9SAndroid Build Coastguard Worker
2272*da0073e9SAndroid Build Coastguard Worker                if use_beta1 and "betas" in param_group.keys():
2273*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
2274*da0073e9SAndroid Build Coastguard Worker                        momentum_target[batch_num],
2275*da0073e9SAndroid Build Coastguard Worker                        param_group["betas"][0],
2276*da0073e9SAndroid Build Coastguard Worker                        msg="Beta1 is wrong in batch_num {}: expected {}, got {}".format(
2277*da0073e9SAndroid Build Coastguard Worker                            batch_num,
2278*da0073e9SAndroid Build Coastguard Worker                            momentum_target[batch_num],
2279*da0073e9SAndroid Build Coastguard Worker                            param_group["betas"][0],
2280*da0073e9SAndroid Build Coastguard Worker                        ),
2281*da0073e9SAndroid Build Coastguard Worker                        atol=1e-5,
2282*da0073e9SAndroid Build Coastguard Worker                        rtol=0,
2283*da0073e9SAndroid Build Coastguard Worker                    )
2284*da0073e9SAndroid Build Coastguard Worker                elif "momentum" in param_group.keys():
2285*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
2286*da0073e9SAndroid Build Coastguard Worker                        momentum_target[batch_num],
2287*da0073e9SAndroid Build Coastguard Worker                        param_group["momentum"],
2288*da0073e9SAndroid Build Coastguard Worker                        msg="Momentum is wrong in batch_num {}: expected {}, got {}".format(
2289*da0073e9SAndroid Build Coastguard Worker                            batch_num,
2290*da0073e9SAndroid Build Coastguard Worker                            momentum_target[batch_num],
2291*da0073e9SAndroid Build Coastguard Worker                            param_group["momentum"],
2292*da0073e9SAndroid Build Coastguard Worker                        ),
2293*da0073e9SAndroid Build Coastguard Worker                        atol=1e-5,
2294*da0073e9SAndroid Build Coastguard Worker                        rtol=0,
2295*da0073e9SAndroid Build Coastguard Worker                    )
2296*da0073e9SAndroid Build Coastguard Worker            self.opt.step()
2297*da0073e9SAndroid Build Coastguard Worker            scheduler.step()
2298*da0073e9SAndroid Build Coastguard Worker
2299*da0073e9SAndroid Build Coastguard Worker    def test_cosine_then_cyclic(self):
2300*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/21965
2301*da0073e9SAndroid Build Coastguard Worker
2302*da0073e9SAndroid Build Coastguard Worker        max_lr = 0.3
2303*da0073e9SAndroid Build Coastguard Worker        base_lr = 0.1
2304*da0073e9SAndroid Build Coastguard Worker        optim_lr = 0.5
2305*da0073e9SAndroid Build Coastguard Worker
2306*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Linear(2, 1)
2307*da0073e9SAndroid Build Coastguard Worker        optimizer = SGD(model.parameters(), lr=optim_lr)
2308*da0073e9SAndroid Build Coastguard Worker        lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(
2309*da0073e9SAndroid Build Coastguard Worker            optimizer, T_max=20, eta_min=0.1
2310*da0073e9SAndroid Build Coastguard Worker        )
2311*da0073e9SAndroid Build Coastguard Worker        lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR(
2312*da0073e9SAndroid Build Coastguard Worker            optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3
2313*da0073e9SAndroid Build Coastguard Worker        )
2314*da0073e9SAndroid Build Coastguard Worker
2315*da0073e9SAndroid Build Coastguard Worker        for i in range(40):
2316*da0073e9SAndroid Build Coastguard Worker            optimizer.step()
2317*da0073e9SAndroid Build Coastguard Worker            if i <= lr_scheduler_1.T_max:
2318*da0073e9SAndroid Build Coastguard Worker                lr_scheduler_1.step()
2319*da0073e9SAndroid Build Coastguard Worker            else:
2320*da0073e9SAndroid Build Coastguard Worker                lr_scheduler_2.step()
2321*da0073e9SAndroid Build Coastguard Worker            last_lr = optimizer.param_groups[0]["lr"]
2322*da0073e9SAndroid Build Coastguard Worker
2323*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(last_lr, max_lr)
2324*da0073e9SAndroid Build Coastguard Worker
2325*da0073e9SAndroid Build Coastguard Worker    @parametrize(
2326*da0073e9SAndroid Build Coastguard Worker        "LRClass",
2327*da0073e9SAndroid Build Coastguard Worker        [
2328*da0073e9SAndroid Build Coastguard Worker            partial(LambdaLR, lr_lambda=lambda e: e // 10),
2329*da0073e9SAndroid Build Coastguard Worker            partial(MultiplicativeLR, lr_lambda=lambda: 0.95),
2330*da0073e9SAndroid Build Coastguard Worker            partial(StepLR, step_size=30),
2331*da0073e9SAndroid Build Coastguard Worker            partial(MultiStepLR, milestones=[30, 80]),
2332*da0073e9SAndroid Build Coastguard Worker            ConstantLR,
2333*da0073e9SAndroid Build Coastguard Worker            LinearLR,
2334*da0073e9SAndroid Build Coastguard Worker            partial(ExponentialLR, gamma=0.9),
2335*da0073e9SAndroid Build Coastguard Worker            lambda opt, **kwargs: SequentialLR(
2336*da0073e9SAndroid Build Coastguard Worker                opt,
2337*da0073e9SAndroid Build Coastguard Worker                schedulers=[ConstantLR(opt), ConstantLR(opt)],
2338*da0073e9SAndroid Build Coastguard Worker                milestones=[2],
2339*da0073e9SAndroid Build Coastguard Worker                **kwargs,
2340*da0073e9SAndroid Build Coastguard Worker            ),
2341*da0073e9SAndroid Build Coastguard Worker            PolynomialLR,
2342*da0073e9SAndroid Build Coastguard Worker            partial(CosineAnnealingLR, T_max=10),
2343*da0073e9SAndroid Build Coastguard Worker            ReduceLROnPlateau,
2344*da0073e9SAndroid Build Coastguard Worker            partial(CyclicLR, base_lr=0.01, max_lr=0.1),
2345*da0073e9SAndroid Build Coastguard Worker            partial(CosineAnnealingWarmRestarts, T_0=20),
2346*da0073e9SAndroid Build Coastguard Worker            partial(OneCycleLR, max_lr=0.01, total_steps=10),
2347*da0073e9SAndroid Build Coastguard Worker        ],
2348*da0073e9SAndroid Build Coastguard Worker    )
2349*da0073e9SAndroid Build Coastguard Worker    def test_lr_scheduler_verbose_deprecation_warning(self, LRClass):
2350*da0073e9SAndroid Build Coastguard Worker        """Check that a deprecating warning with verbose parameter."""
2351*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
2352*da0073e9SAndroid Build Coastguard Worker            UserWarning, "The verbose parameter is deprecated"
2353*da0073e9SAndroid Build Coastguard Worker        ):
2354*da0073e9SAndroid Build Coastguard Worker            LRClass(self.opt, verbose=True)
2355*da0073e9SAndroid Build Coastguard Worker
2356*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
2357*da0073e9SAndroid Build Coastguard Worker            UserWarning, "The verbose parameter is deprecated"
2358*da0073e9SAndroid Build Coastguard Worker        ):
2359*da0073e9SAndroid Build Coastguard Worker            LRClass(self.opt, verbose=False)
2360*da0073e9SAndroid Build Coastguard Worker
2361*da0073e9SAndroid Build Coastguard Worker        # No warning is raised when verbose is the default value.
2362*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings():
2363*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("error", UserWarning)
2364*da0073e9SAndroid Build Coastguard Worker            LRClass(self.opt)
2365*da0073e9SAndroid Build Coastguard Worker
2366*da0073e9SAndroid Build Coastguard Worker    @parametrize(
2367*da0073e9SAndroid Build Coastguard Worker        "LRClass",
2368*da0073e9SAndroid Build Coastguard Worker        [
2369*da0073e9SAndroid Build Coastguard Worker            partial(LambdaLR, lr_lambda=lambda e: e // 10),
2370*da0073e9SAndroid Build Coastguard Worker            partial(MultiplicativeLR, lr_lambda=lambda: 0.95),
2371*da0073e9SAndroid Build Coastguard Worker            partial(StepLR, step_size=30),
2372*da0073e9SAndroid Build Coastguard Worker            partial(MultiStepLR, milestones=[30, 80]),
2373*da0073e9SAndroid Build Coastguard Worker            ConstantLR,
2374*da0073e9SAndroid Build Coastguard Worker            LinearLR,
2375*da0073e9SAndroid Build Coastguard Worker            partial(ExponentialLR, gamma=0.9),
2376*da0073e9SAndroid Build Coastguard Worker            PolynomialLR,
2377*da0073e9SAndroid Build Coastguard Worker            partial(CosineAnnealingLR, T_max=10),
2378*da0073e9SAndroid Build Coastguard Worker            lambda opt, **kwargs: ChainedScheduler(
2379*da0073e9SAndroid Build Coastguard Worker                schedulers=[ConstantLR(opt), ConstantLR(opt)], **kwargs
2380*da0073e9SAndroid Build Coastguard Worker            ),
2381*da0073e9SAndroid Build Coastguard Worker            lambda opt, **kwargs: SequentialLR(
2382*da0073e9SAndroid Build Coastguard Worker                opt,
2383*da0073e9SAndroid Build Coastguard Worker                schedulers=[ConstantLR(opt), ConstantLR(opt)],
2384*da0073e9SAndroid Build Coastguard Worker                milestones=[2],
2385*da0073e9SAndroid Build Coastguard Worker                **kwargs,
2386*da0073e9SAndroid Build Coastguard Worker            ),
2387*da0073e9SAndroid Build Coastguard Worker            ReduceLROnPlateau,
2388*da0073e9SAndroid Build Coastguard Worker            partial(CyclicLR, base_lr=0.01, max_lr=0.1),
2389*da0073e9SAndroid Build Coastguard Worker            partial(OneCycleLR, max_lr=0.01, total_steps=10, anneal_strategy="linear"),
2390*da0073e9SAndroid Build Coastguard Worker            partial(CosineAnnealingWarmRestarts, T_0=20),
2391*da0073e9SAndroid Build Coastguard Worker        ],
2392*da0073e9SAndroid Build Coastguard Worker    )
2393*da0073e9SAndroid Build Coastguard Worker    @parametrize("weights_only", [True, False])
2394*da0073e9SAndroid Build Coastguard Worker    def test_lr_scheduler_state_dict_load(self, LRClass, weights_only):
2395*da0073e9SAndroid Build Coastguard Worker        scheduler = LRClass(self.opt)
2396*da0073e9SAndroid Build Coastguard Worker        state_dict = scheduler.state_dict()
2397*da0073e9SAndroid Build Coastguard Worker
2398*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryFile() as f:
2399*da0073e9SAndroid Build Coastguard Worker            torch.save(state_dict, f)
2400*da0073e9SAndroid Build Coastguard Worker            f.seek(0)
2401*da0073e9SAndroid Build Coastguard Worker            state_dict_loaded = torch.load(f, weights_only=weights_only)
2402*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(state_dict, state_dict_loaded)
2403*da0073e9SAndroid Build Coastguard Worker            # Make sure state_dict can be loaded
2404*da0073e9SAndroid Build Coastguard Worker            scheduler2 = LRClass(self.opt)
2405*da0073e9SAndroid Build Coastguard Worker            scheduler2.load_state_dict(state_dict_loaded)
2406*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scheduler2.state_dict(), state_dict)
2407*da0073e9SAndroid Build Coastguard Worker
2408*da0073e9SAndroid Build Coastguard Worker    @parametrize(
2409*da0073e9SAndroid Build Coastguard Worker        "LRClass",
2410*da0073e9SAndroid Build Coastguard Worker        [
2411*da0073e9SAndroid Build Coastguard Worker            partial(LambdaLR, lr_lambda=lambda e: e // 10),
2412*da0073e9SAndroid Build Coastguard Worker            partial(MultiplicativeLR, lr_lambda=lambda e: 0.95),
2413*da0073e9SAndroid Build Coastguard Worker            partial(StepLR, step_size=30),
2414*da0073e9SAndroid Build Coastguard Worker            partial(MultiStepLR, milestones=[30, 80]),
2415*da0073e9SAndroid Build Coastguard Worker            ConstantLR,
2416*da0073e9SAndroid Build Coastguard Worker            LinearLR,
2417*da0073e9SAndroid Build Coastguard Worker            partial(ExponentialLR, gamma=0.9),
2418*da0073e9SAndroid Build Coastguard Worker            PolynomialLR,
2419*da0073e9SAndroid Build Coastguard Worker            partial(CosineAnnealingLR, T_max=10),
2420*da0073e9SAndroid Build Coastguard Worker            partial(CosineAnnealingWarmRestarts, T_0=20),
2421*da0073e9SAndroid Build Coastguard Worker        ],
2422*da0073e9SAndroid Build Coastguard Worker    )
2423*da0073e9SAndroid Build Coastguard Worker    def test_constant_initial_lr(self, LRClass):
2424*da0073e9SAndroid Build Coastguard Worker        # Test that the initial learning rate is constant
2425*da0073e9SAndroid Build Coastguard Worker        lr = torch.as_tensor(0.1)
2426*da0073e9SAndroid Build Coastguard Worker        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
2427*da0073e9SAndroid Build Coastguard Worker        sch = LRClass(opt)
2428*da0073e9SAndroid Build Coastguard Worker
2429*da0073e9SAndroid Build Coastguard Worker        ori_param_groups = copy.deepcopy(opt.param_groups)
2430*da0073e9SAndroid Build Coastguard Worker
2431*da0073e9SAndroid Build Coastguard Worker        for i in range(2):
2432*da0073e9SAndroid Build Coastguard Worker            opt.step()
2433*da0073e9SAndroid Build Coastguard Worker            sch.step(i)
2434*da0073e9SAndroid Build Coastguard Worker            lr.multiply_(0.1)
2435*da0073e9SAndroid Build Coastguard Worker            for group, ori_group in zip(opt.param_groups, ori_param_groups):
2436*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
2437*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sch.base_lrs, [0.1])
2438*da0073e9SAndroid Build Coastguard Worker
2439*da0073e9SAndroid Build Coastguard Worker    def test_constant_initial_params_cyclelr(self):
2440*da0073e9SAndroid Build Coastguard Worker        # Test that the initial learning rate is constant
2441*da0073e9SAndroid Build Coastguard Worker        lr = torch.as_tensor(0.1)
2442*da0073e9SAndroid Build Coastguard Worker        max_lr = torch.as_tensor(0.2)
2443*da0073e9SAndroid Build Coastguard Worker        base_momentum = torch.as_tensor(0.8)
2444*da0073e9SAndroid Build Coastguard Worker        max_momentum = torch.as_tensor(0.9)
2445*da0073e9SAndroid Build Coastguard Worker        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
2446*da0073e9SAndroid Build Coastguard Worker        sch = CyclicLR(
2447*da0073e9SAndroid Build Coastguard Worker            opt,
2448*da0073e9SAndroid Build Coastguard Worker            base_lr=lr,
2449*da0073e9SAndroid Build Coastguard Worker            max_lr=max_lr,
2450*da0073e9SAndroid Build Coastguard Worker            base_momentum=base_momentum,
2451*da0073e9SAndroid Build Coastguard Worker            max_momentum=max_momentum,
2452*da0073e9SAndroid Build Coastguard Worker        )
2453*da0073e9SAndroid Build Coastguard Worker        ori_param_groups = copy.deepcopy(opt.param_groups)
2454*da0073e9SAndroid Build Coastguard Worker
2455*da0073e9SAndroid Build Coastguard Worker        for i in range(2):
2456*da0073e9SAndroid Build Coastguard Worker            lr.multiply_(0.5)
2457*da0073e9SAndroid Build Coastguard Worker            max_lr.multiply_(0.5)
2458*da0073e9SAndroid Build Coastguard Worker            base_momentum.multiply_(0.5)
2459*da0073e9SAndroid Build Coastguard Worker            max_momentum.multiply_(0.5)
2460*da0073e9SAndroid Build Coastguard Worker            opt.step()
2461*da0073e9SAndroid Build Coastguard Worker            sch.step(i)
2462*da0073e9SAndroid Build Coastguard Worker            for group, ori_group in zip(opt.param_groups, ori_param_groups):
2463*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
2464*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["max_momentum"], ori_group["max_momentum"])
2465*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["base_momentum"], ori_group["base_momentum"])
2466*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sch.base_lrs, [0.1])
2467*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sch.max_lrs, [0.2])
2468*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["max_momentum"], 0.9)
2469*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["base_momentum"], 0.8)
2470*da0073e9SAndroid Build Coastguard Worker
2471*da0073e9SAndroid Build Coastguard Worker    def test_constant_initial_params_onecyclelr(self):
2472*da0073e9SAndroid Build Coastguard Worker        # Test that the initial learning rate is constant
2473*da0073e9SAndroid Build Coastguard Worker        lr = torch.as_tensor(0.1)
2474*da0073e9SAndroid Build Coastguard Worker        base_momentum = torch.as_tensor(0.85)
2475*da0073e9SAndroid Build Coastguard Worker        max_momentum = torch.as_tensor(0.95)
2476*da0073e9SAndroid Build Coastguard Worker        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
2477*da0073e9SAndroid Build Coastguard Worker        sch = OneCycleLR(
2478*da0073e9SAndroid Build Coastguard Worker            opt,
2479*da0073e9SAndroid Build Coastguard Worker            max_lr=lr,
2480*da0073e9SAndroid Build Coastguard Worker            total_steps=10,
2481*da0073e9SAndroid Build Coastguard Worker            base_momentum=base_momentum,
2482*da0073e9SAndroid Build Coastguard Worker            max_momentum=max_momentum,
2483*da0073e9SAndroid Build Coastguard Worker        )
2484*da0073e9SAndroid Build Coastguard Worker        ori_param_groups = copy.deepcopy(opt.param_groups)
2485*da0073e9SAndroid Build Coastguard Worker
2486*da0073e9SAndroid Build Coastguard Worker        for i in range(2):
2487*da0073e9SAndroid Build Coastguard Worker            lr.multiply_(0.5)
2488*da0073e9SAndroid Build Coastguard Worker            base_momentum.multiply_(0.5)
2489*da0073e9SAndroid Build Coastguard Worker            max_momentum.multiply_(0.5)
2490*da0073e9SAndroid Build Coastguard Worker            opt.step()
2491*da0073e9SAndroid Build Coastguard Worker            sch.step(i)
2492*da0073e9SAndroid Build Coastguard Worker
2493*da0073e9SAndroid Build Coastguard Worker            for group, ori_group in zip(opt.param_groups, ori_param_groups):
2494*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
2495*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["max_lr"], ori_group["max_lr"])
2496*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["min_lr"], ori_group["min_lr"])
2497*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["max_momentum"], ori_group["max_momentum"])
2498*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["base_momentum"], ori_group["base_momentum"])
2499*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["max_momentum"], 0.95)
2500*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["base_momentum"], 0.85)
2501*da0073e9SAndroid Build Coastguard Worker
2502*da0073e9SAndroid Build Coastguard Worker    def test_constant_initial_params_swalr(self):
2503*da0073e9SAndroid Build Coastguard Worker        # Test that the initial learning rate is constant
2504*da0073e9SAndroid Build Coastguard Worker        lr = torch.as_tensor(0.1)
2505*da0073e9SAndroid Build Coastguard Worker        swa_lr = torch.as_tensor(0.05)
2506*da0073e9SAndroid Build Coastguard Worker        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
2507*da0073e9SAndroid Build Coastguard Worker        sch = SWALR(opt, swa_lr=swa_lr)
2508*da0073e9SAndroid Build Coastguard Worker        ori_param_groups = copy.deepcopy(opt.param_groups)
2509*da0073e9SAndroid Build Coastguard Worker
2510*da0073e9SAndroid Build Coastguard Worker        for i in range(2):
2511*da0073e9SAndroid Build Coastguard Worker            lr.multiply_(0.5)
2512*da0073e9SAndroid Build Coastguard Worker            swa_lr.multiply_(0.5)
2513*da0073e9SAndroid Build Coastguard Worker            opt.step()
2514*da0073e9SAndroid Build Coastguard Worker            sch.step()
2515*da0073e9SAndroid Build Coastguard Worker            for group, ori_group in zip(opt.param_groups, ori_param_groups):
2516*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
2517*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["swa_lr"], ori_group["swa_lr"])
2518*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(group["swa_lr"], 0.05)
2519*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sch.base_lrs, [0.1])
2520*da0073e9SAndroid Build Coastguard Worker
2521*da0073e9SAndroid Build Coastguard Worker
2522*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestLRScheduler)
2523*da0073e9SAndroid Build Coastguard Worker
2524*da0073e9SAndroid Build Coastguard Worker
2525*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
2526*da0073e9SAndroid Build Coastguard Worker    print("These tests should be run through test/test_optim.py instead")
2527