// xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_module_api.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>

#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

namespace torch {
namespace jit {

static constexpr c10::string_view moduleInterfaceSrc = R"JIT(
class OneInterface(ModuleInterface):
    def one(self, x: Tensor, y: Tensor) -> Tensor:
        pass
)JIT";

static const std::vector<std::string> subModuleMethodsSrc = {R"JIT(
def one(self, x: Tensor, y: Tensor) -> Tensor:
    return self.attr * x + y + 1

def forward(self, x: Tensor) -> Tensor:
    return self.attr + x
)JIT"};

static const std::string parentForward = R"JIT(
def forward(self, x: Tensor) -> Tensor:
    return self.subMod1.one(x, x) + self.subMod2.one(x, x)
)JIT";

// Import the type named `class_name`, defined in `src`, into the given
// CompilationUnit.
static void import_libs(
    std::shared_ptr<CompilationUnit> cu,
    const std::string& class_name,
    const std::shared_ptr<Source>& src,
    const std::vector<at::IValue>& tensor_table) {
  SourceImporter si(
      cu,
      &tensor_table,
      [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
      /*version=*/2);
  si.loadType(QualifiedName(class_name));
}
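// For reference, a typical call (as used in CloneWithModuleInterface below)
// registers the interface type defined in moduleInterfaceSrc with the
// CompilationUnit so that submodule attributes can be declared with an
// interface type:
//
//   std::vector<at::IValue> constantTable;
//   import_libs(
//       cu,
//       "__torch__.OneInterface",
//       std::make_shared<Source>(moduleInterfaceSrc),
//       constantTable);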

TEST(ModuleAPITest, MethodRunAsync) {
  // Module m("m");
  // m.define(R"(
  //   def forward(self):
  //     r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
  //     r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
  //     return r1.wait() + r2.wait()
  // )");
  std::string filePath(__FILE__);
  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  // borrow model file from TEST(GraphExecutorTest, runAsync_executor)
  testModelFile.append("test_interpreter_async.pt");
  auto m = load(testModelFile);

  auto counter = 0;
  std::mutex mtx;

  auto launcher = [&](std::function<void()> f) {
    mtx.lock();
    ++counter;
    mtx.unlock();
    at::launch(std::move(f));
  };

  auto method = m.get_method("forward");

  std::vector<IValue> stack;
  auto kwargs = std::unordered_map<std::string, at::IValue>();
  auto future = method.run_async(stack, kwargs, launcher);

  future->wait();

  // Expect the 2 forks and 2 wait callbacks to be executed via the provided
  // taskLauncher, but the ivalue::Future may be marked completed (releasing
  // wait()) before all callbacks have finished, so only require at least 2.
  ASSERT_GE(counter, 2);
}

TEST(ModuleAPITest, Clone) {
  auto cu = std::make_shared<CompilationUnit>();
  // create the child module type
  auto child = ClassType::create("child", cu, true);
  auto attr_name = "attr";
  child->addAttribute(attr_name, IntType::get());
  Module c1(cu, child);
  auto v1 = IValue(2);
  c1.register_attribute(attr_name, IntType::get(), v1, false);
  Module c2(cu, child);
  auto v2 = IValue(3);
  c2.register_attribute(attr_name, IntType::get(), v2, false);

  // attach the two child module instances, which share a ClassType,
  // to a parent module
  auto parent = ClassType::create("parent", cu, true);
  Module p(cu, parent);
  p.register_attribute("c1", c1.type(), c1._ivalue(), false);
  p.register_attribute("c2", c2.type(), c2._ivalue(), false);

  // clone the parent
  Module p2 = p.clone();
  // check that the two child modules still share the same ClassType
  ASSERT_EQ(p2.attr("c1").type(), p2.attr("c2").type());
  // but are different instances
  ASSERT_EQ(Module(p2.attr("c1").toObject()).attr(attr_name).toInt(), 2);
  ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3);
}

TEST(ModuleAPITest, CloneWithModuleInterface) {
  auto cu = std::make_shared<CompilationUnit>();

  // define an initial module with two submodules that share the same interface
  Module parentMod("parentMod", cu);
  Module subMod1("subMod1", cu);
  Module subMod2("subMod2", cu);

  std::vector<at::IValue> constantTable;
  import_libs(
      cu,
      "__torch__.OneInterface",
      std::make_shared<Source>(moduleInterfaceSrc),
      constantTable);

  auto v1 = IValue(2);
  subMod1.register_attribute("attr", IntType::get(), v1, false);

  auto v2 = IValue(4);
  subMod2.register_attribute("attr", IntType::get(), v2, false);

  for (const std::string& method : subModuleMethodsSrc) {
    subMod1.define(method, nativeResolver());
    subMod2.define(method, nativeResolver());
  }

  parentMod.register_attribute(
      "subMod1",
      cu->get_interface("__torch__.OneInterface"),
      subMod1._ivalue());
  parentMod.register_attribute(
      "subMod2",
      cu->get_interface("__torch__.OneInterface"),
      subMod2._ivalue());

  parentMod.define(parentForward, nativeResolver());

  Module clonedMod = parentMod.clone();

  // clone copies both type and data, so the cloned module gets a
  // different type
  ASSERT_NE(clonedMod.type(), parentMod.type());
}

TEST(ModuleAPITest, Copy) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr_name = "attr";
  cls->addAttribute(attr_name, IntType::get());
  Module m(cu, cls);
  auto v = IValue(2);
  m.register_attribute(attr_name, IntType::get(), v, false);

  Module m2 = m.clone();
  Module m3 = m.copy();

  // Make sure copy works
  ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m3.attr(attr_name).toInt(), 2);

  // clone copies both type and data, so the clone gets a different type
  ASSERT_NE(m.type(), m2.type());
  // copy only copies data; the type is shared
  ASSERT_EQ(m.type(), m3.type());

  // change the value of the copied instance
  m3.register_attribute(attr_name, IntType::get(), IValue(3), false);
  // Verify that the value of the cloned instance doesn't change
  ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
}

TEST(ModuleAPITest, DeepCopy) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto str_attr = "str_attr";
  auto int_attr = "int_attr";
  auto tensor_attr = "tensor_attr";
  auto tensor_list_attr = "tensor_list_attr";
  cls->addAttribute(int_attr, IntType::get());
  cls->addAttribute(str_attr, StringType::get());
  cls->addAttribute(tensor_attr, TensorType::get());
  cls->addAttribute(tensor_list_attr, ListType::ofTensors());
  Module m(cu, cls);
  c10::List<at::Tensor> list({at::rand(5), at::rand(5)});
  m.setattr(int_attr, IValue(2));
  m.setattr(str_attr, IValue("str"));
  m.setattr(tensor_attr, at::randn(5));
  m.setattr(tensor_list_attr, list);

  Module m2 = m.deepcopy();
  Module m3 = m.copy();
  // Make sure copy works
  ASSERT_EQ(m2.attr(int_attr).toInt(), 2);
  ASSERT_EQ(m3.attr(int_attr).toInt(), 2);

  // Test overlaps
  ASSERT_TRUE(!IValue(m2._ivalue()).overlaps(IValue(m._ivalue())));
  ASSERT_TRUE(IValue(m3._ivalue()).overlaps(IValue(m._ivalue())));

  // Both deepcopy and copy will preserve the type
  ASSERT_EQ(m.type(), m2.type());
  ASSERT_EQ(m.type(), m3.type());

  // change int value of copied instances
  m2.setattr(int_attr, IValue(3));
  m3.setattr(int_attr, IValue(4));

  // Verify value of original instance doesn't change
  ASSERT_EQ(m.attr(int_attr).toInt(), 2);
  ASSERT_EQ(m2.attr(int_attr).toInt(), 3);
  ASSERT_EQ(m3.attr(int_attr).toInt(), 4);

  // change Tensor value of copied instances
  at::Tensor t1 = m.attr(tensor_attr).toTensor();
  at::Tensor t2 =
      m2.attr(tensor_attr).toTensor(); // deepcopy will copy the Tensor
  at::Tensor t3 =
      m3.attr(tensor_attr).toTensor(); // copy will not copy the Tensor
  // check copy works
  ASSERT_TRUE(t1.equal(t2));
  ASSERT_TRUE(t1.equal(t3));

  // zero out t1
  t1.zero_();
  // check that t2 is not affected because it is a deep copy
  ASSERT_TRUE(!t1.equal(t2));
  // check that t3 is the same as t1 since it is a shallow copy
  ASSERT_TRUE(t1.equal(t3));
}
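// Taken together, the Clone, Copy, and DeepCopy tests above exercise the
// three copying APIs: clone() duplicates both the ClassType and the data,
// copy() shares the ClassType and the underlying Tensor storage, and
// deepcopy() shares the ClassType but duplicates Tensor storage (while
// preserving aliasing between attributes, as checked further below).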

TEST(ModuleAPITest, DeepCopyString) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr1 = "attr1";
  cls->addAttribute(attr1, StringType::get());
  std::string str = "str";
  Module m(cu, cls);
  m.setattr(attr1, str);
  auto copied = m.deepcopy();
  auto original_str = str;
  ASSERT_EQ(copied.attr(attr1).toStringRef(), original_str);
  // check string mutation is not reflected in the copied module
  str += "str";
  ASSERT_EQ(copied.attr(attr1).toStringRef(), original_str);
}

TEST(ModuleAPITest, DeepCopyEnum) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto enum_attr = "enum_attr";
  auto int_enum_type = EnumType::create(
      "enum_class",
      IntType::get(),
      {{"enum_name_1", 1}, {"enum_name_2", 2}},
      cu);
  cls->addAttribute(enum_attr, int_enum_type);
  Module m(cu, cls);
  m.setattr(
      enum_attr,
      IValue(c10::make_intrusive<ivalue::EnumHolder>(
          int_enum_type, "enum_name_1", 1)));
  Module m2 = m.deepcopy();

  // Make sure deepcopy works
  c10::ivalue::EnumHolder* m2_holder = m2.attr(enum_attr).toEnumHolder().get();
  ASSERT_EQ(m2_holder->value().toInt(), 1);
  ASSERT_EQ(m2_holder->name(), "enum_name_1");
  ASSERT_EQ(m2_holder->type(), int_enum_type);

  // Test overlaps
  ASSERT_TRUE(!IValue(m2._ivalue()).overlaps(IValue(m._ivalue())));

  // Deepcopy will preserve the type
  ASSERT_EQ(m.type(), m2.type());

  // Change original, should not affect deepcopy
  m.setattr(
      enum_attr,
      IValue(c10::make_intrusive<ivalue::EnumHolder>(
          int_enum_type, "enum_name_2", 2)));
  ASSERT_NE(
      m.attr(enum_attr).toEnumHolder().get()->value().toInt(),
      m2.attr(enum_attr).toEnumHolder().get()->value().toInt());
}

TEST(ModuleAPITest, DeepCopyPreservesAliasing) {
  // check deepcopy preserves aliasing
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr1 = "attr1";
  auto attr2 = "attr2";
  auto attr3 = "attr3";
  auto attr4 = "attr4";
  cls->addAttribute(attr1, ListType::ofTensors());
  cls->addAttribute(attr2, ListType::ofTensors());
  cls->addAttribute(attr3, TensorType::get());
  cls->addAttribute(attr4, TensorType::get());
  Module m(cu, cls);
  auto t1 = at::rand(5);
  auto t2 = at::rand(5);
  auto t3 = at::rand(5);
  auto t4 = at::rand({5, 2});
  c10::List<at::Tensor> list1({t1, t2});
  c10::List<at::Tensor> list2({t1, t3});
  // the first elements of attr1 and attr2 are aliased
  m.setattr(attr1, list1);
  m.setattr(attr2, list2);
  m.setattr(attr3, t4);
  m.setattr(attr4, t4.view(-1));

  auto copied = m.deepcopy();
  // test tensor aliasing
  auto copied_attr1_t1 = copied.attr(attr1).toList().get(0);
  auto copied_attr2_t1 = copied.attr(attr2).toList().get(0);
  ASSERT_TRUE(copied_attr1_t1.isAliasOf(copied_attr2_t1));

  // test aliasing from view: attr4 was registered as a view of attr3
  auto copied_attr3 = copied.attr(attr3);
  auto copied_attr4 = copied.attr(attr4);
  ASSERT_TRUE(copied_attr3.isAliasOf(copied_attr4));
}

TEST(ModuleAPITest, Constants) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr_name = "attr";
  auto const_name = "const";
  cls->addAttribute(attr_name, IntType::get());
  cls->addConstant(const_name, IValue(3));
  Module m(cu, cls);
  auto v = IValue(2);
  m.register_attribute(attr_name, IntType::get(), v, false);
  ASSERT_TRUE(m.hasattr(attr_name));
  ASSERT_TRUE(m.hasattr(const_name));
  ASSERT_EQ(m.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m.attr(const_name).toInt(), 3);
}

TEST(ModuleAPITest, Parameters) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  Module m(cu, cls);
  // Tensor parameter
  m.register_parameter(
      "tensor_param", at::empty({3}, at::kFloat), /* is_buffer */ false);
  // None parameter
  m.register_attribute(
      "none_param", NoneType::get(), IValue(), /* is_param */ true);
  m.register_attribute(
      "none_param2", NoneType::get(), IValue(), /* is_param */ true);
  auto param_list = m.parameters();
  ASSERT_EQ(param_list.size(), 1);
  ASSERT_TRUE(m.hasattr("tensor_param"));
  ASSERT_TRUE(m.hasattr("none_param"));
  ASSERT_TRUE(m.hasattr("none_param2"));
}

TEST(ModuleAPITest, Define) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x, b : int = 4):
      return self.foo + x + b
  )");
  auto result = m.run_method("add_it", torch::ones({}));
  AT_ASSERT(result.toTensor().item<float>() == 6);
}

TEST(ModuleAPITest, Freezing) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def forward(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.eval();
  auto forward_g = m.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);

  // Removal of GetAttr is done by freezing
  auto frozen_mod = torch::jit::freeze(m);
  forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);

  // If no training attribute is set, the module is NOT frozen by OFI
  auto frozen_mod2 = torch::jit::optimize_for_inference(m);
  forward_g = frozen_mod2.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);
}

TEST(ModuleAPITest, OfiFreezesTraining) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def forward(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.register_attribute("training", BoolType::get(), true);
  m.eval();

  // Before freezing, we have a GetAttr check
  auto forward_g = m.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);

  // Demonstrate that freezing happens when OFI is called
  // Removal of GetAttr is done by freezing, but only when the training
  // attribute is set
  auto frozen_mod = torch::jit::optimize_for_inference(m);
  forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
}

TEST(ModuleAPITest, OfiFreezesNoForward) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def bar(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.eval();

  // OFI is called on a module without a forward method
  auto frozen_mod =
      torch::jit::optimize_for_inference(m, std::vector<std::string>{"bar"});
  ASSERT_EQ(
      m.run_method("bar", torch::ones({})).toTensor().item<float>(),
      frozen_mod.run_method("bar", torch::ones({})).toTensor().item<float>());
}

TEST(ModuleAPITest, To_CUDA) {
  Module m("test");
  {
    // test cuda to cpu for params and buffers
    m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
    m.register_buffer("bar", torch::ones({}, at::kCUDA));

    m.to(at::kCUDA);
    m.to(at::kCPU);
    AT_ASSERT(m.attr("foo").toTensor().device().is_cpu());
    AT_ASSERT(m.attr("bar").toTensor().device().is_cpu());
  }
  {
    // test cpu to cuda for params and buffers
    m.register_parameter("foo", torch::ones({}), false);
    m.register_buffer("bar", torch::ones({}));

    m.to(at::kCUDA);
    AT_ASSERT(m.attr("foo").toTensor().device().is_cuda());
    AT_ASSERT(m.attr("bar").toTensor().device().is_cuda());
  }
}

} // namespace jit
} // namespace torch