#include <gtest/gtest.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
9
10 namespace torch {
11 namespace jit {
12
13 /** \brief Parse IR from \p S, print the parsed graph and verify that the output
14 * string matches the original string.
15 *
16 * The function is sensitive to value naming and whitespace, so it should be
17 * used with care. Nevertheless, it helps to keep tests more compact.
18 */
checkRoundtrip(const std::string & s)19 static void checkRoundtrip(const std::string& s) {
20 auto graph = std::make_shared<Graph>();
21 parseIR(s, &*graph);
22 std::ostringstream ss;
23 ss << *graph;
24 std::string parsed = ss.str();
25
26 // Skip whitespace in the beginning of the input string.
27 int i = 0;
28 for (char c : s) {
29 if (!isspace(c)) {
30 break;
31 }
32 i++;
33 }
34 std::string original = s.substr(i, s.size());
35 if (original != parsed) {
36 std::cerr << "Input:" << std::endl << original << std::endl;
37 std::cerr << "Parsed:" << std::endl << parsed << std::endl;
38 }
39 AT_ASSERT(original == parsed);
40 }
41
// Parses a three-node graph and verifies that inputs, outputs, the
// name -> Value map and the def-use wiring of every node match the IR text.
TEST(IRParserTest, Basic) {
  const std::string ir = R"IR(
graph(%0 : Tensor, %1 : Tensor):
  %2 : Tensor = foo::add(%0, %1)
  %res, %3 = foo::mul(%0, %2)
  %x, %y = foo::combine(%res, %2, %3)
  return (%x, %y, %res))IR";

  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);

  // Graph signature.
  AT_ASSERT(graph->inputs().size() == 2);
  AT_ASSERT(graph->outputs().size() == 3);
  Value* in0 = graph->inputs()[0];
  Value* in1 = graph->inputs()[1];
  Value* out_x = graph->outputs()[0];
  Value* out_y = graph->outputs()[1];
  Value* out_res = graph->outputs()[2];

  // Named values must be registered in the map under their textual names.
  AT_ASSERT(vmap["x"] == out_x);
  AT_ASSERT(vmap["y"] == out_y);
  AT_ASSERT(vmap["res"] == out_res);
  AT_ASSERT(vmap["0"] == in0);
  AT_ASSERT(vmap["1"] == in1);

  // %x and %y are produced by the same foo::combine node.
  AT_ASSERT(out_x->node() == out_y->node());
  Node* combine_node = out_x->node();
  Value* v2 = combine_node->inputs()[1];
  Value* v3 = combine_node->inputs()[2];
  AT_ASSERT(vmap["2"] == v2);
  AT_ASSERT(vmap["3"] == v3);

  // Expected operand lists are built up front so that no braced initializer
  // has to appear inside the assertion macros.
  const std::vector<Value*> combine_outputs({out_x, out_y});
  const std::vector<Value*> combine_inputs({out_res, v2, v3});
  AT_ASSERT(
      combine_node->kind().toQualString() == std::string("foo::combine"));
  AT_ASSERT(combine_node->outputs() == combine_outputs);
  AT_ASSERT(combine_node->inputs() == combine_inputs);

  Node* mul_node = out_res->node();
  const std::vector<Value*> mul_inputs({in0, v2});
  const std::vector<Value*> mul_outputs({out_res, v3});
  AT_ASSERT(mul_node->kind().toQualString() == std::string("foo::mul"));
  AT_ASSERT(mul_node->inputs() == mul_inputs);
  AT_ASSERT(mul_node->outputs() == mul_outputs);

  Node* add_node = v2->node();
  const std::vector<Value*> add_inputs({in0, in1});
  const std::vector<Value*> add_outputs({v2});
  AT_ASSERT(add_node->kind().toQualString() == std::string("foo::add"));
  AT_ASSERT(add_node->inputs() == add_inputs);
  AT_ASSERT(add_node->outputs() == add_outputs);
}
85
// Round-trips a graph whose first node owns a block that itself contains a
// node with a nested block.
TEST(IRParserTest, NestedBlock) {
  const std::string ir = R"IR(
graph():
  %0 : Tensor = a::a()
    block0():
      %1 : Tensor = b::b()
        block0():
          %2 : Tensor = c::c()
          -> ()
      -> ()
  %3 : Tensor = d::d()
  return (%3)
)IR";
  checkRoundtrip(ir);
}
100
// Round-trips a graph containing a prim::If node with one block that returns
// a value, surrounded by constants and aten::add nodes.
TEST(IRParserTest, If) {
  const std::string ir = R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int = prim::Constant[value=1]()
  %4 : Tensor = aten::add(%0, %1, %3)
  %5 : Tensor = prim::If(%2)
    block0():
      %6 : int = prim::Constant[value=1]()
      %7 : Tensor = aten::add(%1, %3, %6)
      %8 : int = prim::Constant[value=1]()
      %9 : Tensor = aten::add(%7, %3, %8)
      -> (%9)
  %10 : int = prim::Constant[value=1]()
  %11 : Tensor = aten::add(%5, %3, %10)
  return (%11)
)IR";
  checkRoundtrip(ir);
}
120
// Same shape as the If test, but with negative integer attribute values
// (-1 and -987) on the constants.
TEST(IRParserTest, If2) {
  const std::string ir = R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int = prim::Constant[value=-1]()
  %4 : Tensor = aten::add(%0, %1, %3)
  %5 : Tensor = prim::If(%2)
    block0():
      %6 : int = prim::Constant[value=1]()
      %7 : Tensor = aten::add(%1, %3, %6)
      %8 : int = prim::Constant[value=1]()
      %9 : Tensor = aten::add(%7, %3, %8)
      -> (%9)
  %10 : int = prim::Constant[value=-987]()
  %11 : Tensor = aten::add(%5, %3, %10)
  return (%11)
)IR";
  checkRoundtrip(ir);
}
140
// A graph input written without a type annotation must end up with a type
// that is a subtype of Tensor.
TEST(IRParserTest, InferredTypeIsTensor) {
  const std::string ir = R"IR(
graph(%a):
  return (%a))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  Value* untyped_input = graph->inputs()[0];
  AT_ASSERT(untyped_input->type()->isSubtypeOf(*TensorType::get()));
}
150
// Check that parser correctly handles values reusing the same name: every
// redefinition of %x must create a distinct Value, and each use must refer to
// the most recent definition.
TEST(IRParserTest, ValueReuse) {
  const std::string ir = R"IR(
graph(%x):
  %x = a::a(%x)
  %x = b::b(%x)
  return (%x))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  // Walk backwards from the graph output: b::b -> a::a -> graph input.
  Value* x_input = graph->inputs()[0];
  Value* x_after_b = graph->outputs()[0];
  Node* node_b = x_after_b->node();
  Value* x_after_a = node_b->inputs()[0];
  Node* node_a = x_after_a->node();

  const std::vector<Value*> a_inputs({x_input});
  const std::vector<Value*> a_outputs({x_after_a});
  const std::vector<Value*> b_inputs({x_after_a});
  const std::vector<Value*> b_outputs({x_after_b});
  AT_ASSERT(node_a->inputs() == a_inputs);
  AT_ASSERT(node_a->outputs() == a_outputs);
  AT_ASSERT(node_b->inputs() == b_inputs);
  AT_ASSERT(node_b->outputs() == b_outputs);
}
171
// Check that parser handles attributes and types: int, float, string (with
// escapes), string-list and int-list attributes, plus multiple typed outputs.
TEST(IRParserTest, Attributes) {
  const std::string ir = R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int, %4 : Tensor = qqq::qqq[i_asdf=2, f_asdf=3., s_asdf="hello", ss_asdf=["hello world", "bye bye"]](%0)
  %5 : int, %6 : Tensor = ppp::ppp[i_asdf=2, f_asdf=3., s_asdf="\"\"\"\"\nhe\"llo", q=[3, 2, 4]](%0)
  %7 : float = vvv::vvv[s_asdf="hello"](%0)
  %8 : string = z::z()
  return (%7)
)IR";
  checkRoundtrip(ir);
}
186
187 TEST(IRParserTest, OptionalTypes) {
188 checkRoundtrip(
189 R"IR(
190 graph(%0 : Tensor,
191 %1 : Tensor,
192 %2 : Tensor):
193 %3 : int? = prim::Constant()
194 return (%3)
195 )IR");
196 }
197
198 TEST(IRParserTest, StarTensor) {
199 checkRoundtrip(
200 R"IR(
201 graph(%0 : Tensor,
202 %1 : Tensor,
203 %2 : Tensor):
204 %3 : Float(*, *, *) = prim::Constant()
205 return (%3)
206 )IR");
207 }
208
209 TEST(IRParserTest, UnshapedTensor) {
210 checkRoundtrip(
211 R"IR(
212 graph(%0 : Tensor,
213 %1 : Tensor,
214 %2 : Tensor):
215 %3 : Long() = prim::Constant()
216 return (%3)
217 )IR");
218 }
219
220 TEST(IRParserTest, ShapedTensor) {
221 checkRoundtrip(
222 R"IR(
223 graph(%0 : Tensor,
224 %1 : Tensor,
225 %2 : Tensor):
226 %3 : Double(4, 4, 5) = prim::Constant()
227 return (%3)
228 )IR");
229 }
230
231 TEST(IRParserTest, NestedContrainer) {
232 checkRoundtrip(
233 R"IR(
234 graph():
235 %0 : float[] = prim::Constant[value=[1., 2., 3.]]()
236 %1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]()
237 %2 : (float[], str[]) = prim::TupleConstruct(%0, %1)
238 return (%2)
239 )IR");
240 }
241
242 TEST(IRParserTest, MalformedShapeAnnotation) {
243 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
244 EXPECT_ANY_THROW(checkRoundtrip(
245 R"IR(
246 graph(%0 : Tensor,
247 %1 : Tensor,
248 %2 : Tensor):
249 %3 : Double(4!, 4, 5) = prim::Constant()
250 return (%3)
251 )IR"));
252 }
253
254 TEST(IRParserTest, FileCheck) {
255 auto graph = std::make_shared<Graph>();
256 const std::string& text =
257 R"IR(
258 graph(%a):
259 # CHECK: return
260 return (%a))IR";
261
262 parseIR(text, &*graph);
263 AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get()));
264 torch::jit::testing::FileCheck().run(text, *graph);
265 }
266
267 TEST(IRParserTest, Strides) {
268 auto graph = std::make_shared<Graph>();
269 std::unordered_map<std::string, Value*> vmap;
270 parseIR(
271 R"IR(
272 graph(%a : Float(4, 5),
273 %b : Float(4, 5, strides=[5, 1]),
274 %c : Double(*, *)):
275 return (%a)
276 )IR",
277 &*graph,
278 vmap);
279 Value* a = graph->inputs()[0];
280 Value* b = graph->inputs()[1];
281 Value* c = graph->inputs()[2];
282
283 auto a_type = a->type()->cast<TensorType>();
284 auto a_sizes = *a_type->sizes().concrete_sizes();
285 auto a_strides = a_type->strides().concrete_sizes();
286 AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5);
287 AT_ASSERT(a_strides == std::nullopt);
288
289 auto b_type = b->type()->cast<TensorType>();
290 auto b_sizes = *b_type->sizes().concrete_sizes();
291 auto b_strides = *(b_type->strides().sizes());
292 AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5);
293 AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1);
294
295 auto c_type = c->type()->cast<TensorType>();
296 AT_ASSERT(*c_type->sizes().size() == 2);
297 AT_ASSERT(c_type->sizes().concrete_sizes() == std::nullopt);
298 AT_ASSERT(c_type->strides().concrete_sizes() == std::nullopt);
299 }
300
301 TEST(IRParserTest, MalformedStrides) {
302 auto graph = std::make_shared<Graph>();
303 std::unordered_map<std::string, Value*> vmap;
304 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
305 EXPECT_ANY_THROW(parseIR(
306 R"IR(
307 graph(%a : Float(4, strides=[5], 5)):
308 return (%a)
309 )IR",
310 &*graph,
311 vmap));
312 }
313
314 TEST(IRParserTest, TensorShapes) {
315 checkRoundtrip(
316 R"IR(
317 graph(%a : Float(4, 5),
318 %b : Float(4, 5, strides=[5, 1]),
319 %c : Double(*, *)):
320 return (%a)
321 )IR");
322 }
323
324 TEST(IRParserTest, DeviceAndRequiresGradTensors) {
325 checkRoundtrip(
326 R"IR(
327 graph(%a : Float(*, *, device=cpu),
328 %b : Float(*, *, requires_grad=1),
329 %c : Long(5, 10, requires_grad=1, device=cpu),
330 %d : Float(5, requires_grad=0, device=cuda:2),
331 %e : Long(4, 3, 1, strides=[6, 2, 1], requires_grad=0, device=cuda:1),
332 %f : Float(),
333 %g : Float(device=cpu),
334 %h : Float(requires_grad=1),
335 %i : Float(requires_grad=0, device=cuda:1),
336 %j : Double(*, *, requires_grad=0)):
337 return (%a)
338 )IR");
339 }
340
341 TEST(IRParserTest, ListConstant) {
342 auto graph = std::make_shared<Graph>();
343 parseIR(
344 R"IR(
345 graph():
346 %d : int[] = prim::Constant[value=[1,2,3]]()
347 return (%d)
348 )IR",
349 &*graph);
350 Node* n = graph->outputs()[0]->node();
351 AT_ASSERT(n->kind() == prim::Constant);
352 AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival);
353 const auto& genericList = n->ival(attr::value).toList();
354 std::vector<int> int_vals;
355 // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
356 for (const IValue& ival : genericList) {
357 int_vals.push_back(ival.toInt());
358 }
359 AT_ASSERT(int_vals.size() == 3);
360 AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3);
361 }
362
363 TEST(IRParserTest, PartialStarTensor) {
364 checkRoundtrip(
365 R"IR(
366 graph(%x : Float(10, *, 10)):
367 return (%x)
368 )IR");
369 }
370
371 TEST(IRParserTest, ComplexTensorAttributes) {
372 checkRoundtrip(
373 R"IR(
374 graph(%x : Double(*, 200, *, requires_grad=1, device=cuda:1),
375 %b : Float(5, *, requires_grad=1),
376 %c : Long(*, 10, device=cpu)):
377 return (%x)
378 )IR");
379 }
380 } // namespace jit
381 } // namespace torch
382