#include <gtest/gtest.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <sstream>
#include <string>

namespace torch {
namespace jit {

/** \brief Parse IR from \p s, print the parsed graph, and verify that the
 * printed string matches the original string.
 *
 * The function is sensitive to value naming and whitespace, so it should be
 * used with care. Nevertheless, it helps to keep tests more compact.
 */
static void checkRoundtrip(const std::string& s) {
  auto graph = std::make_shared<Graph>();
  parseIR(s, &*graph);
  std::ostringstream ss;
  ss << *graph;
  std::string parsed = ss.str();

  // Skip whitespace at the beginning of the input string.
  int i = 0;
  for (char c : s) {
    if (!isspace(c)) {
      break;
    }
    i++;
  }
  std::string original = s.substr(i, s.size());
  if (original != parsed) {
    std::cerr << "Input:" << std::endl << original << std::endl;
    std::cerr << "Parsed:" << std::endl << parsed << std::endl;
  }
  AT_ASSERT(original == parsed);
}

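// Parse a small multi-output graph and verify that the value map, node kinds,
// and input/output wiring match the IR text.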
TEST(IRParserTest, Basic) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
graph(%0 : Tensor, %1 : Tensor):
  %2 : Tensor = foo::add(%0, %1)
  %res, %3 = foo::mul(%0, %2)
  %x, %y = foo::combine(%res, %2, %3)
  return (%x, %y, %res))IR",
      &*graph,
      vmap);

  AT_ASSERT(graph->inputs().size() == 2);
  AT_ASSERT(graph->outputs().size() == 3);
  Value* x = graph->outputs()[0];
  Value* y = graph->outputs()[1];
  Value* res = graph->outputs()[2];
  Value* t0 = graph->inputs()[0];
  Value* t1 = graph->inputs()[1];
  AT_ASSERT(vmap["x"] == x);
  AT_ASSERT(vmap["y"] == y);
  AT_ASSERT(vmap["res"] == res);
  AT_ASSERT(vmap["0"] == t0);
  AT_ASSERT(vmap["1"] == t1);
  AT_ASSERT(x->node() == y->node());
  Node* comb = x->node();
  Value* t2 = comb->inputs()[1];
  Value* t3 = comb->inputs()[2];
  AT_ASSERT(vmap["2"] == t2);
  AT_ASSERT(vmap["3"] == t3);
  AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
  AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
  AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
  Node* mul = res->node();
  AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul"));
  AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2}));
  AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3}));
  Node* add = t2->node();
  AT_ASSERT(add->kind().toQualString() == std::string("foo::add"));
  AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1}));
  AT_ASSERT(add->outputs() == std::vector<Value*>({t2}));
}

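// Check that blocks nested inside nodes survive a parse/print round trip.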
TEST(IRParserTest, NestedBlock) {
  checkRoundtrip(R"IR(
graph():
  %0 : Tensor = a::a()
    block0():
      %1 : Tensor = b::b()
        block0():
          %2 : Tensor = c::c()
          -> ()
      -> ()
  %3 : Tensor = d::d()
  return (%3)
)IR");
}

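// Round-trip a graph containing a prim::If node with a single block.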
TEST(IRParserTest, If) {
  checkRoundtrip(R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int = prim::Constant[value=1]()
  %4 : Tensor = aten::add(%0, %1, %3)
  %5 : Tensor = prim::If(%2)
    block0():
      %6 : int = prim::Constant[value=1]()
      %7 : Tensor = aten::add(%1, %3, %6)
      %8 : int = prim::Constant[value=1]()
      %9 : Tensor = aten::add(%7, %3, %8)
      -> (%9)
  %10 : int = prim::Constant[value=1]()
  %11 : Tensor = aten::add(%5, %3, %10)
  return (%11)
)IR");
}

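// Same as the If test, but with negative constant attribute values.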
TEST(IRParserTest, If2) {
  checkRoundtrip(R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int = prim::Constant[value=-1]()
  %4 : Tensor = aten::add(%0, %1, %3)
  %5 : Tensor = prim::If(%2)
    block0():
      %6 : int = prim::Constant[value=1]()
      %7 : Tensor = aten::add(%1, %3, %6)
      %8 : int = prim::Constant[value=1]()
      %9 : Tensor = aten::add(%7, %3, %8)
      -> (%9)
  %10 : int = prim::Constant[value=-987]()
  %11 : Tensor = aten::add(%5, %3, %10)
  return (%11)
)IR");
}

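// Values without an explicit type annotation should default to Tensor.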
TEST(IRParserTest, InferredTypeIsTensor) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%a):
  return (%a))IR",
      &*graph);
  AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get()));
}

TEST(IRParserTest, ValueReuse) {
  // Check that the parser correctly handles values reusing the same name.
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%x):
  %x = a::a(%x)
  %x = b::b(%x)
  return (%x))IR",
      &*graph);
  Value* x0 = graph->inputs()[0];
  Value* x2 = graph->outputs()[0];
  Node* b = x2->node();
  Value* x1 = b->inputs()[0];
  Node* a = x1->node();
  AT_ASSERT(a->inputs() == std::vector<Value*>({x0}));
  AT_ASSERT(a->outputs() == std::vector<Value*>({x1}));
  AT_ASSERT(b->inputs() == std::vector<Value*>({x1}));
  AT_ASSERT(b->outputs() == std::vector<Value*>({x2}));
}

TEST(IRParserTest, Attributes) {
  // Check that the parser handles attributes and types.
  checkRoundtrip(
      R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int, %4 : Tensor = qqq::qqq[i_asdf=2, f_asdf=3., s_asdf="hello", ss_asdf=["hello world", "bye bye"]](%0)
  %5 : int, %6 : Tensor = ppp::ppp[i_asdf=2, f_asdf=3., s_asdf="\"\"\"\"\nhe\"llo", q=[3, 2, 4]](%0)
  %7 : float = vvv::vvv[s_asdf="hello"](%0)
  %8 : string = z::z()
  return (%7)
)IR");
}

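// Round-trip an optional type annotation (int?).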
TEST(IRParserTest, OptionalTypes) {
  checkRoundtrip(
      R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int? = prim::Constant()
  return (%3)
)IR");
}

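// Round-trip a tensor type whose dimensions are all wildcards (*).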
TEST(IRParserTest, StarTensor) {
  checkRoundtrip(
      R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : Float(*, *, *) = prim::Constant()
  return (%3)
)IR");
}

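// Round-trip a tensor type with an empty dimension list (Long()).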
TEST(IRParserTest, UnshapedTensor) {
  checkRoundtrip(
      R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : Long() = prim::Constant()
  return (%3)
)IR");
}

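// Round-trip a tensor type with fully concrete sizes.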
TEST(IRParserTest, ShapedTensor) {
  checkRoundtrip(
      R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : Double(4, 4, 5) = prim::Constant()
  return (%3)
)IR");
}

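// Round-trip list constants nested inside a tuple.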
TEST(IRParserTest, NestedContainer) {
  checkRoundtrip(
      R"IR(
graph():
  %0 : float[] = prim::Constant[value=[1., 2., 3.]]()
  %1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]()
  %2 : (float[], str[]) = prim::TupleConstruct(%0, %1)
  return (%2)
)IR");
}

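// A malformed size annotation (4!) should make the parser throw.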
TEST(IRParserTest, MalformedShapeAnnotation) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_ANY_THROW(checkRoundtrip(
      R"IR(
graph(%0 : Tensor,
    %1 : Tensor,
    %2 : Tensor):
  %3 : Double(4!, 4, 5) = prim::Constant()
  return (%3)
)IR"));
}

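// FileCheck directives embedded in the IR text ("# CHECK: ...") should run
// against the parsed graph.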
TEST(IRParserTest, FileCheck) {
  auto graph = std::make_shared<Graph>();
  const std::string& text =
      R"IR(
    graph(%a):
    # CHECK: return
      return (%a))IR";

  parseIR(text, &*graph);
  AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get()));
  torch::jit::testing::FileCheck().run(text, *graph);
}

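// Check how size and stride annotations are reflected in the parsed TensorType:
// concrete sizes without strides, explicit strides, and wildcard dimensions.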
TEST(IRParserTest, Strides) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
graph(%a : Float(4, 5),
      %b : Float(4, 5, strides=[5, 1]),
      %c : Double(*, *)):
  return (%a)
)IR",
      &*graph,
      vmap);
  Value* a = graph->inputs()[0];
  Value* b = graph->inputs()[1];
  Value* c = graph->inputs()[2];

  auto a_type = a->type()->cast<TensorType>();
  auto a_sizes = *a_type->sizes().concrete_sizes();
  auto a_strides = a_type->strides().concrete_sizes();
  AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5);
  AT_ASSERT(a_strides == std::nullopt);

  auto b_type = b->type()->cast<TensorType>();
  auto b_sizes = *b_type->sizes().concrete_sizes();
  auto b_strides = *(b_type->strides().sizes());
  AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5);
  AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1);

  auto c_type = c->type()->cast<TensorType>();
  AT_ASSERT(*c_type->sizes().size() == 2);
  AT_ASSERT(c_type->sizes().concrete_sizes() == std::nullopt);
  AT_ASSERT(c_type->strides().concrete_sizes() == std::nullopt);
}

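// A strides= entry interleaved with the size list is malformed and should make
// the parser throw.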
TEST(IRParserTest, MalformedStrides) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_ANY_THROW(parseIR(
      R"IR(
graph(%a : Float(4, strides=[5], 5)):
  return (%a)
)IR",
      &*graph,
      vmap));
}

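// Round-trip tensor types with concrete sizes, explicit strides, and wildcard
// dimensions.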
TEST(IRParserTest, TensorShapes) {
  checkRoundtrip(
      R"IR(
graph(%a : Float(4, 5),
      %b : Float(4, 5, strides=[5, 1]),
      %c : Double(*, *)):
  return (%a)
)IR");
}

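// Round-trip device= and requires_grad= tensor modifiers in various combinations.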
TEST(IRParserTest, DeviceAndRequiresGradTensors) {
  checkRoundtrip(
      R"IR(
graph(%a : Float(*, *, device=cpu),
      %b : Float(*, *, requires_grad=1),
      %c : Long(5, 10, requires_grad=1, device=cpu),
      %d : Float(5, requires_grad=0, device=cuda:2),
      %e : Long(4, 3, 1, strides=[6, 2, 1], requires_grad=0, device=cuda:1),
      %f : Float(),
      %g : Float(device=cpu),
      %h : Float(requires_grad=1),
      %i : Float(requires_grad=0, device=cuda:1),
      %j : Double(*, *, requires_grad=0)):
  return (%a)
)IR");
}

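// An int[] constant should be stored as an ival attribute holding the list
// elements.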
TEST(IRParserTest, ListConstant) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph():
  %d : int[] = prim::Constant[value=[1,2,3]]()
  return (%d)
)IR",
      &*graph);
  Node* n = graph->outputs()[0]->node();
  AT_ASSERT(n->kind() == prim::Constant);
  AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival);
  const auto& genericList = n->ival(attr::value).toList();
  std::vector<int> int_vals;
  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
  for (const IValue& ival : genericList) {
    int_vals.push_back(ival.toInt());
  }
  AT_ASSERT(int_vals.size() == 3);
  AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3);
}

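// Round-trip a tensor type mixing concrete and wildcard dimensions.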
TEST(IRParserTest, PartialStarTensor) {
  checkRoundtrip(
      R"IR(
graph(%x : Float(10, *, 10)):
  return (%x)
)IR");
}

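// Round-trip tensor types combining wildcard dimensions with device= and
// requires_grad= modifiers.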
TEST(IRParserTest, ComplexTensorAttributes) {
  checkRoundtrip(
      R"IR(
graph(%x : Double(*, 200, *, requires_grad=1, device=cuda:1),
      %b : Float(5, *, requires_grad=1),
      %c : Long(*, 10, device=cpu)):
  return (%x)
)IR");
}
} // namespace jit
} // namespace torch