#include <torch/optim/adagrad.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <functional>

namespace torch {
namespace optim {

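// Typical construction (a sketch; `model` and the option values are
// illustrative, not taken from this file):
//   Adagrad optimizer(
//       model->parameters(),
//       AdagradOptions(/*lr=*/1e-2).weight_decay(1e-4).eps(1e-10));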
AdagradOptions::AdagradOptions(double lr) : lr_(lr) {}

bool operator==(const AdagradOptions& lhs, const AdagradOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.lr_decay() == rhs.lr_decay()) &&
      (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.initial_accumulator_value() == rhs.initial_accumulator_value()) &&
      (lhs.eps() == rhs.eps());
}

void AdagradOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(initial_accumulator_value);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
}

void AdagradOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, initial_accumulator_value);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
}

double AdagradOptions::get_lr() const {
  return lr();
}

void AdagradOptions::set_lr(const double lr) {
  this->lr(lr);
}

bool operator==(const AdagradParamState& lhs, const AdagradParamState& rhs) {
  return (lhs.step() == rhs.step()) && torch::equal(lhs.sum(), rhs.sum());
}

void AdagradParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(sum);
}

void AdagradParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, sum);
}

/// Adapted from
/// https://github.com/pytorch/pytorch/blob/master/torch/optim/adagrad.py
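/// In the dense case, the loop below effectively computes, per parameter:
///   sum   += grad * grad
///   clr    = lr / (1 + (step - 1) * lr_decay)
///   param -= clr * grad / (sqrt(sum) + eps)
/// with an optional L2 term folded into grad when weight_decay != 0.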
Tensor Adagrad::step(LossClosure closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
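  // If a closure was supplied, evaluate it with grad mode re-enabled so that
  // the loss (and, typically, fresh gradients) is recomputed for this step.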
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto grad = p.grad();
      TORCH_INTERNAL_ASSERT(
          state_[p.unsafeGetTensorImpl()] != nullptr,
          "state found NULL for the Tensor ",
          p);
      auto& state =
          static_cast<AdagradParamState&>(*state_[p.unsafeGetTensorImpl()]);
      auto& options = static_cast<AdagradOptions&>(group.options());

      state.step(state.step() + 1);

      if (options.weight_decay() != 0) {
        TORCH_CHECK(
            !p.grad().is_sparse(),
            "weight_decay option is not compatible with sparse gradients");
        grad = grad.add(p, options.weight_decay());
      }
      const auto clr = options.lr() /
          (1 + static_cast<double>(state.step() - 1) * options.lr_decay());

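      // Sparse gradients: accumulate grad^2 and apply the update only on the
      // entries named by the gradient's indices; make_sparse() rebuilds a
      // sparse tensor over those same indices.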
      if (grad.is_sparse()) {
        grad = grad.coalesce();
        auto grad_indices = grad._indices();
        auto grad_values = grad._values();
        auto size = grad.sizes();

        auto make_sparse = [&](const Tensor& values) -> Tensor {
          if (grad_indices.dim() == 0 || values.dim() == 0) {
            return torch::empty({0}, grad.options()).resize_as_(grad);
          }
          return torch::sparse_coo_tensor(
              grad_indices, values, size, grad.options());
        };
        state.sum(state.sum().add_(make_sparse(grad_values.pow(2))));
        auto std = state.sum().sparse_mask(grad);
        const auto std_values = std._values().sqrt_().add_(options.eps());

        p.add_(make_sparse(grad_values / std_values), -clr);
      } else {
        state.sum(state.sum().addcmul_(grad, grad, 1.0));
        const auto std = state.sum().sqrt().add_(options.eps());
        p.addcdiv_(grad, std, -clr);
      }
    }
  }
  return loss;
}
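// A typical call site (a sketch; `optimizer`, `model`, `input`, and `target`
// are illustrative and not defined in this file):
//
//   auto loss = optimizer.step([&] {
//     optimizer.zero_grad();
//     auto out = model->forward(input);
//     auto loss = torch::nll_loss(out, target);
//     loss.backward();
//     return loss;
//   });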

void Adagrad::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void Adagrad::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized Adagrad optimizer is still using the old serialization format. "
        "You should re-save your Adagrad optimizer to use the new serialization format.");
    std::vector<Tensor> sum_buffers;
    std::vector<int64_t> step_buffers;
    torch::optim::serialize(archive, "sum_buffers", sum_buffers);
    torch::optim::serialize(archive, "step_buffers", step_buffers);
    // since there were no param_groups prior to version 1.5.0, assume all
    // tensors are now in one param_group
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (const auto idx : c10::irange(params.size())) {
      auto state = std::make_unique<AdagradParamState>();
      state->step(step_buffers[idx]);
      state->sum(sum_buffers[idx]);
      state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
    }
  }
}
} // namespace optim
} // namespace torch