1 #include <torch/optim/serialize.h>
2
3 #include <torch/serialize/archive.h>
4 #include <torch/types.h>
5
6 #include <cstddef>
7 #include <cstdint>
8 #include <deque>
9 #include <string>
10 #include <vector>
11
12 namespace torch {
13 namespace optim {
serialize(serialize::OutputArchive & archive,const std::string & key,const int64_t & value)14 void serialize(
15 serialize::OutputArchive& archive,
16 const std::string& key,
17 const int64_t& value) {
18 archive.write(key, IValue(value));
19 }
20
serialize(serialize::InputArchive & archive,const std::string & key,int64_t & value)21 void serialize(
22 serialize::InputArchive& archive,
23 const std::string& key,
24 int64_t& value) {
25 IValue ivalue;
26 archive.read(key, ivalue);
27 value = ivalue.toInt();
28 }
29
serialize(serialize::OutputArchive & archive,const std::string & key,const std::vector<int64_t> & steps)30 void serialize(
31 serialize::OutputArchive& archive,
32 const std::string& key,
33 const std::vector<int64_t>& steps) {
34 std::vector<torch::Tensor> tensors;
35 tensors.reserve(steps.size());
36 for (const auto& step : steps) {
37 tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
38 }
39 serialize(archive, key, tensors);
40 }
41
serialize(serialize::InputArchive & archive,const std::string & key,std::vector<int64_t> & steps)42 void serialize(
43 serialize::InputArchive& archive,
44 const std::string& key,
45 std::vector<int64_t>& steps) {
46 steps.clear();
47 std::vector<torch::Tensor> tensors;
48 serialize(archive, key, tensors);
49 for (const auto& step : tensors) {
50 steps.push_back(step.item<int64_t>());
51 }
52 }
53 } // namespace optim
54 } // namespace torch
55