1 *da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2 *da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
3 *da0073e9SAndroid Build Coastguard Worker #include <algorithm>
4 *da0073e9SAndroid Build Coastguard Worker #include <memory>
5 *da0073e9SAndroid Build Coastguard Worker #include <vector>
6 *da0073e9SAndroid Build Coastguard Worker
7 *da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
8 *da0073e9SAndroid Build Coastguard Worker
9 *da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
10 *da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
11 *da0073e9SAndroid Build Coastguard Worker
12 *da0073e9SAndroid Build Coastguard Worker struct ModuleDictTest : torch::test::SeedingFixture {};
13 *da0073e9SAndroid Build Coastguard Worker
// A ModuleDict can be constructed from a vector of (name, module) pairs.
TEST_F(ModuleDictTest, ConstructsFromList) {
  // Minimal module type carrying only an integer payload.
  struct M : Module {
    explicit M(int v) : value(v) {}
    int value;
  };

  // Assemble the input pairs one at a time instead of with a braced list.
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> entries;
  entries.reserve(3);
  entries.emplace_back("module_1", std::make_shared<M>(1));
  entries.emplace_back("module_2", std::make_shared<M>(2));
  entries.emplace_back("module_3", std::make_shared<M>(3));

  ModuleDict dict(entries);
  ASSERT_EQ(dict->size(), 3);
}
27 *da0073e9SAndroid Build Coastguard Worker
// A ModuleDict can be constructed from a torch::OrderedDict of modules.
TEST_F(ModuleDictTest, ConstructsFromordereddict) {
  // Minimal module type carrying only an integer payload.
  struct M : Module {
    explicit M(int v) : value(v) {}
    int value;
  };

  // Populate the OrderedDict incrementally rather than via an initializer list.
  torch::OrderedDict<std::string, std::shared_ptr<Module>> source;
  source.insert("module_1", std::make_shared<M>(1));
  source.insert("module_2", std::make_shared<M>(2));
  source.insert("module_3", std::make_shared<M>(3));

  ModuleDict dict(source);
  ASSERT_EQ(dict->size(), 3);
}
42 *da0073e9SAndroid Build Coastguard Worker
// Exercises update() from every supported source (vector of pairs,
// OrderedDict, another ModuleDict), then pop(), pop() of a missing key, and
// clear(), checking membership after each step.
TEST_F(ModuleDictTest, UpdatePopClearContains) {
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  ModuleDict dict;
  ASSERT_TRUE(dict->empty());
  // Update by List
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
      {"module_1", std::make_shared<M>(1)}};
  dict->update(list1);
  ASSERT_EQ(dict->size(), 1);
  ASSERT_TRUE(dict->contains("module_1"));
  // Update by OrderedDict
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"module_2", std::make_shared<M>(2)}};
  dict->update(ordereddict);
  ASSERT_EQ(dict->size(), 2);
  ASSERT_TRUE(dict->contains("module_2"));
  // Update by another ModuleDict
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
      {"module_3", std::make_shared<M>(3)}};
  ModuleDict updatedict(list2);
  dict->update(*updatedict);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(dict->contains("module_3"));
  // Pop: a size check alone would pass even if the wrong entry were removed,
  // so also verify which key disappeared and which module was returned.
  auto popped = dict->pop("module_1");
  ASSERT_EQ(dict->size(), 2);
  ASSERT_FALSE(dict->contains("module_1"));
  ASSERT_TRUE(dict->contains("module_2"));
  ASSERT_TRUE(dict->contains("module_3"));
  auto* popped_m = popped->as<M>();
  ASSERT_TRUE(popped_m);
  ASSERT_EQ(popped_m->value, 1);
  // Pop unexist: must throw and leave the dict untouched.
  ASSERT_THROWS_WITH(dict->pop("module_4"), " 'module_4' is not defined");
  ASSERT_EQ(dict->size(), 2);
  // Clear
  dict->clear();
  ASSERT_EQ(dict->size(), 0);
  ASSERT_TRUE(dict->empty());
}
79 *da0073e9SAndroid Build Coastguard Worker
// update() must overwrite modules whose keys already exist and append the
// rest, whether the source is a vector of pairs, an OrderedDict, or a
// ModuleDict.
TEST_F(ModuleDictTest, UpdateExist) {
  struct M : Module {
    explicit M(int v) : value(v) {}
    int value;
  };
  using Pairs = std::vector<std::pair<std::string, std::shared_ptr<Module>>>;
  using Ordered = torch::OrderedDict<std::string, std::shared_ptr<Module>>;

  Pairs initial = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)}};
  ModuleDict dict(initial);

  // Reads the integer payload of the M stored under `key`.
  auto value_at = [&dict](const std::string& key) {
    return dict->at<M>(key).value;
  };

  ASSERT_EQ(value_at("module_2"), 2);

  // Source 1: vector of pairs — overwrites module_2, appends module_3.
  Pairs from_pairs = {
      {"module_2", std::make_shared<M>(0)},
      {"module_3", std::make_shared<M>(3)}};
  dict->update(from_pairs);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_EQ(value_at("module_2"), 0);

  // Source 2: OrderedDict — overwrites module_3, appends module_4.
  Ordered from_ordered = {
      {"module_3", std::make_shared<M>(0)},
      {"module_4", std::make_shared<M>(4)}};
  dict->update(from_ordered);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(value_at("module_3"), 0);

  // Source 3: another ModuleDict — only overwrites, so the size is unchanged.
  Pairs overrides = {
      {"module_4", std::make_shared<M>(0)},
      {"module_1", std::make_shared<M>(0)}};
  ModuleDict from_dict(overrides);
  dict->update(*from_dict);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(value_at("module_1"), 0);
  ASSERT_EQ(value_at("module_4"), 0);
}
114 *da0073e9SAndroid Build Coastguard Worker
// keys() returns names in insertion order; operator[] on an unknown name
// throws; operator[] on a known name yields a module castable to its real type.
TEST_F(ModuleDictTest, Keys) {
  // (The unused local `struct M` from the original test was removed.)
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict);
  const auto& keys = dict->keys();
  std::vector<std::string> expected{"linear", "conv", "dropout"};
  ASSERT_EQ(keys, expected);
  ASSERT_THROWS_WITH(dict["batch"], " 'batch' is not defined");

  ASSERT_TRUE(dict["linear"]->as<Linear>());
  ASSERT_TRUE(dict["conv"]->as<Conv2d>());
  ASSERT_TRUE(dict["dropout"]->as<Dropout>());
}
136 *da0073e9SAndroid Build Coastguard Worker
// values() and iteration must expose exactly the module instances the dict
// was built from, in insertion order.
TEST_F(ModuleDictTest, Values) {
  struct M : Module {
    explicit M(int v) : value(v) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> source = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)},
  };
  ModuleDict dict(source);

  // Same shared_ptrs, same order.
  ASSERT_EQ(dict->values(), source.values());

  // Walking the dict visits the very same Module objects as the source.
  auto holds_same_module = [](const auto& lhs, const auto& rhs) {
    return lhs.value().get() == rhs.value().get();
  };
  ASSERT_TRUE(std::equal(
      dict->begin(), dict->end(), source.begin(), holds_same_module));
}
159 *da0073e9SAndroid Build Coastguard Worker
// A ModuleDict can hold a mix of standard torch::nn modules. Besides
// constructing without throwing, every entry must be stored in insertion order.
TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(ordereddict);

  ASSERT_EQ(dict->size(), 6);
  std::vector<std::string> expected{
      "linear", "conv", "dropout", "batch", "embedding", "lstm"};
  ASSERT_EQ(dict->keys(), expected);
}
170 *da0073e9SAndroid Build Coastguard Worker
// Two ModuleDicts built from the same OrderedDict share module instances
// (the shared_ptrs are copied, not the modules).
TEST_F(ModuleDictTest, HasReferenceSemantics) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> shared_source = {
      {"linear1", Linear(2, 3).ptr()},
      {"linear2", Linear(3, 4).ptr()},
      {"linear3", Linear(4, 5).ptr()},
  };
  ModuleDict first(shared_source);
  ModuleDict second(shared_source);

  ASSERT_EQ(first->size(), second->size());

  // Find the first position where the two dicts point at different modules;
  // there must be none.
  const auto same_instance = [](const auto& lhs, const auto& rhs) {
    return lhs.value().get() == rhs.value().get();
  };
  const auto divergence = std::mismatch(
      first->begin(), first->end(), second->begin(), same_instance);
  ASSERT_TRUE(divergence.first == first->end());
}
189 *da0073e9SAndroid Build Coastguard Worker
iscloneable_helper(torch::Device device)190 *da0073e9SAndroid Build Coastguard Worker void iscloneable_helper(torch::Device device) {
191 *da0073e9SAndroid Build Coastguard Worker torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
192 *da0073e9SAndroid Build Coastguard Worker {"linear", Linear(2, 3).ptr()},
193 *da0073e9SAndroid Build Coastguard Worker {"relu", Functional(torch::relu).ptr()},
194 *da0073e9SAndroid Build Coastguard Worker {"batch", BatchNorm1d(3).ptr()},
195 *da0073e9SAndroid Build Coastguard Worker };
196 *da0073e9SAndroid Build Coastguard Worker ModuleDict dict(ordereddict);
197 *da0073e9SAndroid Build Coastguard Worker dict->to(device);
198 *da0073e9SAndroid Build Coastguard Worker ModuleDict clone =
199 *da0073e9SAndroid Build Coastguard Worker std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
200 *da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(dict->size(), clone->size());
201 *da0073e9SAndroid Build Coastguard Worker
202 *da0073e9SAndroid Build Coastguard Worker for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end();
203 *da0073e9SAndroid Build Coastguard Worker ++it, ++it_c) {
204 *da0073e9SAndroid Build Coastguard Worker // The key should be same
205 *da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(it->key(), it_c->key());
206 *da0073e9SAndroid Build Coastguard Worker // The modules should be the same kind (type).
207 *da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(it->value()->name(), it_c->value()->name());
208 *da0073e9SAndroid Build Coastguard Worker // But not pointer-equal (distinct objects).
209 *da0073e9SAndroid Build Coastguard Worker ASSERT_NE(it->value(), it_c->value());
210 *da0073e9SAndroid Build Coastguard Worker }
211 *da0073e9SAndroid Build Coastguard Worker
212 *da0073e9SAndroid Build Coastguard Worker // Verify that the clone is deep, i.e. parameters of modules are cloned too.
213 *da0073e9SAndroid Build Coastguard Worker torch::NoGradGuard no_grad;
214 *da0073e9SAndroid Build Coastguard Worker
215 *da0073e9SAndroid Build Coastguard Worker auto params1 = dict->named_parameters();
216 *da0073e9SAndroid Build Coastguard Worker auto params2 = clone->named_parameters();
217 *da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(params1.size(), params2.size());
218 *da0073e9SAndroid Build Coastguard Worker for (auto& param : params1) {
219 *da0073e9SAndroid Build Coastguard Worker ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
220 *da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(param->device(), params2[param.key()].device());
221 *da0073e9SAndroid Build Coastguard Worker ASSERT_TRUE(param->allclose(params2[param.key()]));
222 *da0073e9SAndroid Build Coastguard Worker param->add_(2);
223 *da0073e9SAndroid Build Coastguard Worker }
224 *da0073e9SAndroid Build Coastguard Worker for (auto& param : params1) {
225 *da0073e9SAndroid Build Coastguard Worker ASSERT_FALSE(param->allclose(params2[param.key()]));
226 *da0073e9SAndroid Build Coastguard Worker }
227 *da0073e9SAndroid Build Coastguard Worker }
228 *da0073e9SAndroid Build Coastguard Worker
// Deep-clone check on the CPU.
TEST_F(ModuleDictTest, IsCloneable) {
  iscloneable_helper(torch::Device(torch::kCPU));
}
232 *da0073e9SAndroid Build Coastguard Worker
// Deep-clone check on the first CUDA device.
TEST_F(ModuleDictTest, IsCloneable_CUDA) {
  iscloneable_helper(torch::Device(torch::kCUDA, 0));
}
236 *da0073e9SAndroid Build Coastguard Worker
// Entries of a ModuleDict are registered as child modules, in insertion
// order. Updating an existing key replaces the child in place (keeping its
// position), while new keys are appended.
TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict1 = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"test", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict1);

  auto modules = dict->children();
  // Check the count before indexing so a regression fails cleanly instead of
  // reading past the end of the vector.
  ASSERT_EQ(modules.size(), 3);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout>());

  // Update Existing
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict2 = {
      {"lstm", LSTM(4, 5).ptr()}, {"test", BatchNorm2d(5).ptr()}};
  dict->update(ordereddict2);

  modules = dict->children();
  ASSERT_EQ(modules.size(), 4);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  // Keep Order: "test" was replaced in place; "lstm" was appended.
  ASSERT_TRUE(modules[2]->as<BatchNorm2d>());
  ASSERT_TRUE(modules[3]->as<LSTM>());
}
262 *da0073e9SAndroid Build Coastguard Worker
// clone(device) must place every parameter and buffer of the clone on the
// requested CUDA device.
TEST_F(ModuleDictTest, CloneToDevice_CUDA) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> source = {
      {"linear", Linear(2, 3).ptr()},
      {"relu", Functional(torch::relu).ptr()},
      {"batch", BatchNorm1d(3).ptr()},
  };
  ModuleDict dict(source);
  const torch::Device target(torch::kCUDA, 0);
  ModuleDict clone =
      std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(target));

  // Check parameters first, then buffers, in one pass.
  auto tensors = clone->parameters();
  const auto buffers = clone->buffers();
  tensors.insert(tensors.end(), buffers.begin(), buffers.end());
  for (const auto& tensor : tensors) {
    ASSERT_EQ(tensor.device(), target);
  }
}
280 *da0073e9SAndroid Build Coastguard Worker
// Streaming a ModuleDict prints each entry as "(key): <module repr>", one per
// line, in insertion order.
TEST_F(ModuleDictTest, PrettyPrintModuleDict) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> source = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(source);

  const std::string expected =
      "torch::nn::ModuleDict(\n"
      "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")";
  ASSERT_EQ(c10::str(dict), expected);
}
302 *da0073e9SAndroid Build Coastguard Worker
// at<T>() with a type that does not match the stored module must throw.
TEST_F(ModuleDictTest, InvalidAt) {
  // A dict whose only entry is a Linear module.
  ModuleDict dict(torch::OrderedDict<std::string, std::shared_ptr<Module>>{
      {"linear", Linear(10, 3).ptr()}});
  // Asking for it as an unrelated module type must fail the downcast.
  ASSERT_THROWS_WITH(
      dict->at<torch::nn::Dropout2dImpl>("linear"), "Unable to cast module");
}
310