xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1  #include <c10/util/irange.h>
2  #include <torch/optim/schedulers/lr_scheduler.h>
3  
4  namespace torch {
5  namespace optim {
6  
LRScheduler(torch::optim::Optimizer & optimizer)7  LRScheduler::LRScheduler(torch::optim::Optimizer& optimizer)
8      : optimizer_(optimizer) {}
9  
step()10  void LRScheduler::step() {
11    std::vector<double> learning_rates = get_lrs();
12    set_optimizer_lrs(learning_rates);
13    step_count_++;
14  }
15  
set_optimizer_lrs(const std::vector<double> & learning_rates)16  void LRScheduler::set_optimizer_lrs(const std::vector<double>& learning_rates) {
17    // Check the number of learning rates is equal to the number of parameters
18    // groups in the optimizer
19    TORCH_CHECK(
20        learning_rates.size() == optimizer_.param_groups().size(),
21        "Number of learning rates not equal to the number of param groups\n",
22        "Number of learning rates given: ",
23        learning_rates.size(),
24        "\nNumber of param groups: ",
25        optimizer_.param_groups().size());
26  
27    for (const auto i : c10::irange(optimizer_.param_groups().size())) {
28      optimizer_.param_groups()[i].options().set_lr(learning_rates[i]);
29    }
30  }
31  
get_current_lrs() const32  std::vector<double> LRScheduler::get_current_lrs() const {
33    std::vector<double> learnings_rates(optimizer_.param_groups().size());
34    if (!learnings_rates.empty()) {
35      for (const auto i : c10::irange(optimizer_.param_groups().size())) {
36        learnings_rates[i] = optimizer_.param_groups()[i].options().get_lr();
37      }
38    }
39    return learnings_rates;
40  }
41  
42  } // namespace optim
43  } // namespace torch
44