1 #include <torch/csrc/jit/python/python_tree_views.h>
2
3 #include <torch/csrc/jit/frontend/tree_views.h>
4
5 #include <pybind11/pybind11.h>
6 #include <pybind11/stl.h>
7 #include <torch/csrc/utils/pybind.h>
8
9 #include <sstream>
10
11 namespace py = pybind11;
12
13 namespace torch::jit {
14
maybeConvertToString(const py::object & obj)15 std::optional<std::string> maybeConvertToString(const py::object& obj) {
16 if (obj.is_none()) {
17 return std::nullopt;
18 }
19 std::stringstream ss;
20 ss << py::str(obj);
21 return ss.str();
22 }
23
24 struct SourceRangeFactory {
SourceRangeFactorytorch::jit::SourceRangeFactory25 SourceRangeFactory(
26 const std::string& text,
27 const py::object& filename,
28 size_t file_lineno,
29 size_t leading_whitespace_chars)
30 : source_(std::make_shared<Source>(
31 text,
32 maybeConvertToString(filename),
33 file_lineno)),
34 leading_whitespace_chars_(leading_whitespace_chars) {}
35
createtorch::jit::SourceRangeFactory36 SourceRange create(int line, int start_col, int end_col) {
37 auto [start_byte_offset, end_byte_offset] = line_col_to_byte_offs(
38 line,
39 start_col + leading_whitespace_chars_,
40 end_col + leading_whitespace_chars_);
41 return SourceRange(source_, start_byte_offset, end_byte_offset);
42 }
43
line_col_to_byte_offstorch::jit::SourceRangeFactory44 std::tuple<size_t, size_t> line_col_to_byte_offs(
45 int line,
46 size_t start_col,
47 size_t end_col) {
48 // lines are counted from 1.
49 line--;
50 auto line_start = source_->offset_for_line(line);
51 return std::make_tuple<size_t, size_t>(
52 line_start + start_col, line_start + end_col);
53 }
54
55 std::shared_ptr<Source> source_;
56 std::vector<size_t> line_len_prefix_sum_;
57 size_t leading_whitespace_chars_;
58 };
59
60 template <typename T>
wrap_list(const SourceRange & fallback_pos,std::vector<T> && vec)61 List<T> wrap_list(const SourceRange& fallback_pos, std::vector<T>&& vec) {
62 if (vec.empty())
63 return List<T>::create(fallback_pos, std::move(vec));
64 return List<T>::create(vec.front().range(), std::move(vec));
65 }
66
67 template <typename T>
wrap_maybe(const SourceRange & fallback_pos,T * val)68 Maybe<T> wrap_maybe(const SourceRange& fallback_pos, T* val) {
69 return val ? Maybe<T>::create(val->range(), *val)
70 : Maybe<T>::create(fallback_pos);
71 }
72
initTreeViewBindings(PyObject * module)73 void initTreeViewBindings(PyObject* module) {
74 auto _C = py::handle(module).cast<py::module>();
75 auto m = _C.def_submodule("_jit_tree_views");
76
77 py::class_<SourceRange>(m, "SourceRange")
78 .def(
79 "highlight",
80 [](const SourceRange& self) {
81 std::ostringstream stream;
82 self.highlight(stream);
83 return stream.str();
84 })
85 .def("__repr__", [](const SourceRange& self) { return self.str(); })
86 .def(
87 "__str__",
88 [](const SourceRange& self) {
89 return "SourceRange at:\n" + self.str();
90 })
91 .def_property_readonly("start", &SourceRange::start)
92 .def_property_readonly("end", &SourceRange::end);
93 py::class_<SourceRangeFactory>(m, "SourceRangeFactory")
94 .def(py::init<std::string&&, py::object, size_t, size_t>())
95 .def("make_range", &SourceRangeFactory::create)
96 .def(
97 "make_raw_range",
98 [](const SourceRangeFactory& self, size_t start, size_t end) {
99 return SourceRange(self.source_, start, end);
100 })
101 .def_property_readonly("source", [](const SourceRangeFactory& self) {
102 auto text_view = self.source_->text_str().str();
103 return text_view;
104 });
105
106 py::class_<TreeView>(m, "TreeView")
107 .def("range", &TreeView::range)
108 .def(
109 "__str__",
110 [](const TreeView& tree) {
111 std::ostringstream stream;
112 stream << tree.get();
113 return stream.str();
114 })
115 .def("dump", [](const TreeView& tree) { tree.dump(); });
116
117 py::class_<Ident, TreeView>(m, "Ident")
118 .def(py::init(&Ident::create))
119 .def_property_readonly(
120 "name", [](const Ident& self) { return self.name(); });
121
122 py::class_<Param, TreeView>(m, "Param")
123 .def(py::init([](const Expr& type, const Ident& name, bool kwarg_only) {
124 return Param::create(
125 name.range(),
126 name,
127 Maybe<Expr>::create(type.range(), type),
128 Maybe<Expr>::create(name.range()),
129 kwarg_only);
130 }))
131 .def(py::init(
132 [](const Maybe<Expr>& type, const Ident& name, bool kwarg_only) {
133 return Param::create(
134 name.range(),
135 name,
136 type,
137 Maybe<Expr>::create(name.range()),
138 kwarg_only);
139 }));
140 py::class_<Attribute, TreeView>(m, "Attribute")
141 .def(py::init([](const Ident& name, const Expr& value) {
142 return Attribute::create(name.range(), name, value);
143 }));
144 m.def("TrueLiteral", [](const SourceRange& range) {
145 return Expr(Compound::create(TK_TRUE, range, {}));
146 });
147 m.def("FalseLiteral", [](const SourceRange& range) {
148 return Expr(Compound::create(TK_FALSE, range, {}));
149 });
150 m.def("NoneLiteral", [](const SourceRange& range) {
151 return Expr(Compound::create(TK_NONE, range, {}));
152 });
153
154 py::class_<Stmt, TreeView>(m, "Stmt") // NOLINT(bugprone-unused-raii)
155 .def(py::init([](const TreeView& thing) { return Stmt(thing.get()); }));
156 py::class_<Expr, TreeView>(m, "Expr"); // NOLINT(bugprone-unused-raii)
157 py::class_<Def, TreeView>(m, "Def")
158 .def(py::init(
159 [](const Ident& name, const Decl& decl, std::vector<Stmt> body) {
160 const auto& r = name.range();
161 return Def::create(r, name, decl, wrap_list(r, std::move(body)));
162 }))
163 .def("decl", [](const Def& def) { return def.decl(); })
164 .def("name", [](const Def& def) { return def.name(); });
165 py::class_<Property, TreeView>(m, "Property")
166 .def(py::init([](const SourceRange& r,
167 const Ident& name,
168 const Def& getter,
169 Def* setter) {
170 return Property::create(r, name, getter, wrap_maybe(r, setter));
171 }))
172 .def("name", [](const Property& property) { return property.name(); })
173 .def(
174 "getter_name",
175 [](const Property& property) { return property.getter().name(); })
176 .def("setter_name", [](const Property& property) {
177 if (property.setter().present()) {
178 return std::optional<Ident>(property.setter().get().name());
179 }
180
181 return std::optional<Ident>(std::nullopt);
182 });
183
184 py::class_<ClassDef, TreeView>(m, "ClassDef")
185 .def(py::init([](const Ident& name,
186 std::vector<Stmt> body,
187 std::vector<Property> props,
188 std::vector<Assign> assigns) {
189 const auto& r = name.range();
190 return ClassDef::create(
191 r,
192 name,
193 Maybe<Expr>::create(r),
194 wrap_list(r, std::move(body)),
195 wrap_list(r, std::move(props)),
196 wrap_list(r, std::move(assigns)));
197 }));
198
199 py::class_<Decl, TreeView>(m, "Decl").def(py::init(
200 [](const SourceRange& r, std::vector<Param> params, Expr* return_type) {
201 return Decl::create(
202 r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
203 }));
204
205 py::class_<Delete, Stmt>(m, "Delete")
206 .def(py::init([](const SourceRange& range, std::vector<Expr> targets) {
207 return Delete::create(range, wrap_list(range, std::move(targets)));
208 }));
209
210 py::class_<WithItem, Expr>(m, "WithItem")
211 .def(py::init([](const SourceRange& range, const Expr& target, Var* var) {
212 return WithItem::create(range, target, wrap_maybe(range, var));
213 }));
214
215 py::class_<Assign, Stmt>(m, "Assign")
216 .def(py::init([](std::vector<Expr> lhs, const Expr& rhs) {
217 auto li = wrap_list(rhs.range(), std::move(lhs));
218 return Assign::create(
219 li.range(),
220 li,
221 Maybe<Expr>::create(rhs.range(), rhs),
222 Maybe<Expr>::create(li.range()));
223 }))
224 .def(py::init([](std::vector<Expr> lhs, const Expr& rhs, Expr* type) {
225 auto li = wrap_list(rhs.range(), std::move(lhs));
226 return Assign::create(
227 li.range(),
228 li,
229 Maybe<Expr>::create(rhs.range(), rhs),
230 wrap_maybe(li.range(), type));
231 }));
232 py::class_<AugAssign, Stmt>(m, "AugAssign")
233 .def(py::init(
234 [](const Expr& lhs, const std::string& kind_str, const Expr& rhs) {
235 const auto& r = lhs.range();
236 auto kind =
237 AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
238 return AugAssign::create(r, lhs, kind, rhs);
239 }));
240 py::class_<Return, Stmt>(m, "Return")
241 .def(py::init([](const SourceRange& range, Expr* value) {
242 return Return::create(
243 range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
244 }));
245 py::class_<Raise, Stmt>(m, "Raise")
246 .def(py::init([](const SourceRange& range, const Expr& expr) {
247 return Raise::create(range, expr);
248 }));
249 py::class_<Assert, Stmt>(m, "Assert")
250 .def(py::init([](const SourceRange& range, const Expr& test, Expr* msg) {
251 return Assert::create(range, test, wrap_maybe(range, msg));
252 }));
253 py::class_<Pass, Stmt>(m, "Pass").def(
254 py::init([](const SourceRange& range) { return Pass::create(range); }));
255 py::class_<Break, Stmt>(m, "Break")
256 .def(py::init(
257 [](const SourceRange& range) { return Break::create(range); }));
258 py::class_<Continue, Stmt>(m, "Continue")
259 .def(py::init(
260 [](const SourceRange& range) { return Continue::create(range); }));
261 py::class_<Dots, Expr>(m, "Dots").def(
262 py::init([](const SourceRange& range) { return Dots::create(range); }));
263 py::class_<If, Stmt>(m, "If").def(
264 py::init([](const SourceRange& range,
265 const Expr& cond,
266 std::vector<Stmt> true_branch,
267 std::vector<Stmt> false_branch) {
268 return If::create(
269 range,
270 cond,
271 wrap_list(range, std::move(true_branch)),
272 wrap_list(range, std::move(false_branch)));
273 }));
274 py::class_<While, Stmt>(m, "While")
275 .def(py::init([](const SourceRange& range,
276 const Expr& cond,
277 std::vector<Stmt> body) {
278 return While::create(range, cond, wrap_list(range, std::move(body)));
279 }));
280 py::class_<With, Stmt>(m, "With").def(
281 py::init([](const SourceRange& range,
282 std::vector<WithItem> targets,
283 std::vector<Stmt> body) {
284 return With::create(
285 range,
286 wrap_list(range, std::move(targets)),
287 wrap_list(range, std::move(body)));
288 }));
289 py::class_<For, Stmt>(m, "For").def(py::init([](const SourceRange& range,
290 std::vector<Expr>& targets,
291 std::vector<Expr>& itrs,
292 std::vector<Stmt> body) {
293 return For::create(
294 range,
295 wrap_list(range, std::move(targets)),
296 wrap_list(range, std::move(itrs)),
297 wrap_list(range, std::move(body)));
298 }));
299 py::class_<ExprStmt, Stmt>(m, "ExprStmt").def(py::init([](const Expr& expr) {
300 return ExprStmt::create(expr.range(), expr);
301 }));
302
303 py::class_<Var, Expr>(m, "Var")
304 .def(py::init(
305 [](const Ident& name) { return Var::create(name.range(), name); }))
306 .def_property_readonly("name", [](const Var& var) { return var.name(); });
307 py::class_<BinOp, Expr>(m, "BinOp")
308 .def(py::init(
309 [](const std::string& kind, const Expr& lhs, const Expr& rhs) {
310 return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
311 }));
312 // NB: we take range here, because unary ops precede their exprs, so we need
313 // to include them
314 py::class_<UnaryOp, Expr>(m, "UnaryOp")
315 .def(py::init([](const SourceRange& range,
316 const std::string& kind,
317 const Expr& expr) {
318 auto resolved_kind = stringToKind(kind);
319 resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind;
320 return UnaryOp::create(range, resolved_kind, expr);
321 }));
322 py::class_<Const, Expr>(m, "Const")
323 .def(py::init([](const SourceRange& range, const std::string& value) {
324 return Const::create(range, value);
325 }));
326 py::class_<StringLiteral, Expr>(m, "StringLiteral")
327 .def(py::init([](const SourceRange& range, const std::string& value) {
328 return StringLiteral::create(range, value);
329 }));
330 py::class_<Apply, Expr>(m, "Apply")
331 .def(py::init([](const Expr& expr,
332 std::vector<Expr> args,
333 std::vector<Attribute> kwargs) {
334 const auto& r = expr.range();
335 return Apply::create(
336 expr.range(),
337 expr,
338 wrap_list(r, std::move(args)),
339 wrap_list(r, std::move(kwargs)));
340 }));
341 py::class_<Select, Expr>(m, "Select")
342 .def(py::init([](const Expr& expr, const Ident& field) {
343 return Select::create(expr.range(), expr, field);
344 }));
345 py::class_<TernaryIf, Expr>(m, "TernaryIf")
346 .def(py::init(
347 [](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
348 return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
349 }));
350 py::class_<ListComp, Expr>(m, "ListComp")
351 .def(py::init([](const SourceRange& range,
352 const Expr& elt,
353 const Expr& target,
354 const Expr& iter) {
355 return ListComp::create(range, elt, target, iter);
356 }));
357 py::class_<DictComp, Expr>(m, "DictComp")
358 .def(py::init([](const SourceRange& range,
359 const Expr& key,
360 const Expr& value,
361 const Expr& target,
362 const Expr& iter) {
363 return DictComp::create(range, key, value, target, iter);
364 }));
365 py::class_<ListLiteral, Expr>(m, "ListLiteral")
366 .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
367 return ListLiteral::create(range, wrap_list(range, std::move(args)));
368 }));
369 py::class_<TupleLiteral, Expr>(m, "TupleLiteral")
370 .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
371 return TupleLiteral::create(range, wrap_list(range, std::move(args)));
372 }));
373 py::class_<DictLiteral, Expr>(m, "DictLiteral")
374 .def(py::init([](const SourceRange& range,
375 std::vector<Expr> keys,
376 std::vector<Expr> values) {
377 return DictLiteral::create(
378 range,
379 wrap_list(range, std::move(keys)),
380 wrap_list(range, std::move(values)));
381 }));
382 py::class_<Subscript, Expr>(m, "Subscript")
383 .def(py::init([](const Expr& base, std::vector<Expr> subscript_exprs) {
384 return Subscript::create(
385 base.range(),
386 base,
387 wrap_list(base.range(), std::move(subscript_exprs)));
388 }));
389 py::class_<SliceExpr, Expr>(m, "SliceExpr")
390 .def(py::init(
391 [](const SourceRange& range, Expr* lower, Expr* upper, Expr* step) {
392 return SliceExpr::create(
393 range,
394 wrap_maybe(range, lower),
395 wrap_maybe(range, upper),
396 wrap_maybe(range, step));
397 }));
398 py::class_<Starred, Expr>(m, "Starred")
399 .def(py::init([](const SourceRange& range, const Expr& expr) {
400 return Starred::create(range, expr);
401 }));
402 py::class_<Maybe<Expr>, TreeView>(m, "EmptyTypeAnnotation")
403 .def(py::init(
404 [](const SourceRange& range) { return Maybe<Expr>::create(range); }));
405 }
406
407 } // namespace torch::jit
408