xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_tree_views.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/python/python_tree_views.h>
2 
3 #include <torch/csrc/jit/frontend/tree_views.h>
4 
5 #include <pybind11/pybind11.h>
6 #include <pybind11/stl.h>
7 #include <torch/csrc/utils/pybind.h>
8 
9 #include <sstream>
10 
11 namespace py = pybind11;
12 
13 namespace torch::jit {
14 
maybeConvertToString(const py::object & obj)15 std::optional<std::string> maybeConvertToString(const py::object& obj) {
16   if (obj.is_none()) {
17     return std::nullopt;
18   }
19   std::stringstream ss;
20   ss << py::str(obj);
21   return ss.str();
22 }
23 
24 struct SourceRangeFactory {
SourceRangeFactorytorch::jit::SourceRangeFactory25   SourceRangeFactory(
26       const std::string& text,
27       const py::object& filename,
28       size_t file_lineno,
29       size_t leading_whitespace_chars)
30       : source_(std::make_shared<Source>(
31             text,
32             maybeConvertToString(filename),
33             file_lineno)),
34         leading_whitespace_chars_(leading_whitespace_chars) {}
35 
createtorch::jit::SourceRangeFactory36   SourceRange create(int line, int start_col, int end_col) {
37     auto [start_byte_offset, end_byte_offset] = line_col_to_byte_offs(
38         line,
39         start_col + leading_whitespace_chars_,
40         end_col + leading_whitespace_chars_);
41     return SourceRange(source_, start_byte_offset, end_byte_offset);
42   }
43 
line_col_to_byte_offstorch::jit::SourceRangeFactory44   std::tuple<size_t, size_t> line_col_to_byte_offs(
45       int line,
46       size_t start_col,
47       size_t end_col) {
48     // lines are counted from 1.
49     line--;
50     auto line_start = source_->offset_for_line(line);
51     return std::make_tuple<size_t, size_t>(
52         line_start + start_col, line_start + end_col);
53   }
54 
55   std::shared_ptr<Source> source_;
56   std::vector<size_t> line_len_prefix_sum_;
57   size_t leading_whitespace_chars_;
58 };
59 
60 template <typename T>
wrap_list(const SourceRange & fallback_pos,std::vector<T> && vec)61 List<T> wrap_list(const SourceRange& fallback_pos, std::vector<T>&& vec) {
62   if (vec.empty())
63     return List<T>::create(fallback_pos, std::move(vec));
64   return List<T>::create(vec.front().range(), std::move(vec));
65 }
66 
67 template <typename T>
wrap_maybe(const SourceRange & fallback_pos,T * val)68 Maybe<T> wrap_maybe(const SourceRange& fallback_pos, T* val) {
69   return val ? Maybe<T>::create(val->range(), *val)
70              : Maybe<T>::create(fallback_pos);
71 }
72 
initTreeViewBindings(PyObject * module)73 void initTreeViewBindings(PyObject* module) {
74   auto _C = py::handle(module).cast<py::module>();
75   auto m = _C.def_submodule("_jit_tree_views");
76 
77   py::class_<SourceRange>(m, "SourceRange")
78       .def(
79           "highlight",
80           [](const SourceRange& self) {
81             std::ostringstream stream;
82             self.highlight(stream);
83             return stream.str();
84           })
85       .def("__repr__", [](const SourceRange& self) { return self.str(); })
86       .def(
87           "__str__",
88           [](const SourceRange& self) {
89             return "SourceRange at:\n" + self.str();
90           })
91       .def_property_readonly("start", &SourceRange::start)
92       .def_property_readonly("end", &SourceRange::end);
93   py::class_<SourceRangeFactory>(m, "SourceRangeFactory")
94       .def(py::init<std::string&&, py::object, size_t, size_t>())
95       .def("make_range", &SourceRangeFactory::create)
96       .def(
97           "make_raw_range",
98           [](const SourceRangeFactory& self, size_t start, size_t end) {
99             return SourceRange(self.source_, start, end);
100           })
101       .def_property_readonly("source", [](const SourceRangeFactory& self) {
102         auto text_view = self.source_->text_str().str();
103         return text_view;
104       });
105 
106   py::class_<TreeView>(m, "TreeView")
107       .def("range", &TreeView::range)
108       .def(
109           "__str__",
110           [](const TreeView& tree) {
111             std::ostringstream stream;
112             stream << tree.get();
113             return stream.str();
114           })
115       .def("dump", [](const TreeView& tree) { tree.dump(); });
116 
117   py::class_<Ident, TreeView>(m, "Ident")
118       .def(py::init(&Ident::create))
119       .def_property_readonly(
120           "name", [](const Ident& self) { return self.name(); });
121 
122   py::class_<Param, TreeView>(m, "Param")
123       .def(py::init([](const Expr& type, const Ident& name, bool kwarg_only) {
124         return Param::create(
125             name.range(),
126             name,
127             Maybe<Expr>::create(type.range(), type),
128             Maybe<Expr>::create(name.range()),
129             kwarg_only);
130       }))
131       .def(py::init(
132           [](const Maybe<Expr>& type, const Ident& name, bool kwarg_only) {
133             return Param::create(
134                 name.range(),
135                 name,
136                 type,
137                 Maybe<Expr>::create(name.range()),
138                 kwarg_only);
139           }));
140   py::class_<Attribute, TreeView>(m, "Attribute")
141       .def(py::init([](const Ident& name, const Expr& value) {
142         return Attribute::create(name.range(), name, value);
143       }));
144   m.def("TrueLiteral", [](const SourceRange& range) {
145     return Expr(Compound::create(TK_TRUE, range, {}));
146   });
147   m.def("FalseLiteral", [](const SourceRange& range) {
148     return Expr(Compound::create(TK_FALSE, range, {}));
149   });
150   m.def("NoneLiteral", [](const SourceRange& range) {
151     return Expr(Compound::create(TK_NONE, range, {}));
152   });
153 
154   py::class_<Stmt, TreeView>(m, "Stmt") // NOLINT(bugprone-unused-raii)
155       .def(py::init([](const TreeView& thing) { return Stmt(thing.get()); }));
156   py::class_<Expr, TreeView>(m, "Expr"); // NOLINT(bugprone-unused-raii)
157   py::class_<Def, TreeView>(m, "Def")
158       .def(py::init(
159           [](const Ident& name, const Decl& decl, std::vector<Stmt> body) {
160             const auto& r = name.range();
161             return Def::create(r, name, decl, wrap_list(r, std::move(body)));
162           }))
163       .def("decl", [](const Def& def) { return def.decl(); })
164       .def("name", [](const Def& def) { return def.name(); });
165   py::class_<Property, TreeView>(m, "Property")
166       .def(py::init([](const SourceRange& r,
167                        const Ident& name,
168                        const Def& getter,
169                        Def* setter) {
170         return Property::create(r, name, getter, wrap_maybe(r, setter));
171       }))
172       .def("name", [](const Property& property) { return property.name(); })
173       .def(
174           "getter_name",
175           [](const Property& property) { return property.getter().name(); })
176       .def("setter_name", [](const Property& property) {
177         if (property.setter().present()) {
178           return std::optional<Ident>(property.setter().get().name());
179         }
180 
181         return std::optional<Ident>(std::nullopt);
182       });
183 
184   py::class_<ClassDef, TreeView>(m, "ClassDef")
185       .def(py::init([](const Ident& name,
186                        std::vector<Stmt> body,
187                        std::vector<Property> props,
188                        std::vector<Assign> assigns) {
189         const auto& r = name.range();
190         return ClassDef::create(
191             r,
192             name,
193             Maybe<Expr>::create(r),
194             wrap_list(r, std::move(body)),
195             wrap_list(r, std::move(props)),
196             wrap_list(r, std::move(assigns)));
197       }));
198 
199   py::class_<Decl, TreeView>(m, "Decl").def(py::init(
200       [](const SourceRange& r, std::vector<Param> params, Expr* return_type) {
201         return Decl::create(
202             r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
203       }));
204 
205   py::class_<Delete, Stmt>(m, "Delete")
206       .def(py::init([](const SourceRange& range, std::vector<Expr> targets) {
207         return Delete::create(range, wrap_list(range, std::move(targets)));
208       }));
209 
210   py::class_<WithItem, Expr>(m, "WithItem")
211       .def(py::init([](const SourceRange& range, const Expr& target, Var* var) {
212         return WithItem::create(range, target, wrap_maybe(range, var));
213       }));
214 
215   py::class_<Assign, Stmt>(m, "Assign")
216       .def(py::init([](std::vector<Expr> lhs, const Expr& rhs) {
217         auto li = wrap_list(rhs.range(), std::move(lhs));
218         return Assign::create(
219             li.range(),
220             li,
221             Maybe<Expr>::create(rhs.range(), rhs),
222             Maybe<Expr>::create(li.range()));
223       }))
224       .def(py::init([](std::vector<Expr> lhs, const Expr& rhs, Expr* type) {
225         auto li = wrap_list(rhs.range(), std::move(lhs));
226         return Assign::create(
227             li.range(),
228             li,
229             Maybe<Expr>::create(rhs.range(), rhs),
230             wrap_maybe(li.range(), type));
231       }));
232   py::class_<AugAssign, Stmt>(m, "AugAssign")
233       .def(py::init(
234           [](const Expr& lhs, const std::string& kind_str, const Expr& rhs) {
235             const auto& r = lhs.range();
236             auto kind =
237                 AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
238             return AugAssign::create(r, lhs, kind, rhs);
239           }));
240   py::class_<Return, Stmt>(m, "Return")
241       .def(py::init([](const SourceRange& range, Expr* value) {
242         return Return::create(
243             range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
244       }));
245   py::class_<Raise, Stmt>(m, "Raise")
246       .def(py::init([](const SourceRange& range, const Expr& expr) {
247         return Raise::create(range, expr);
248       }));
249   py::class_<Assert, Stmt>(m, "Assert")
250       .def(py::init([](const SourceRange& range, const Expr& test, Expr* msg) {
251         return Assert::create(range, test, wrap_maybe(range, msg));
252       }));
253   py::class_<Pass, Stmt>(m, "Pass").def(
254       py::init([](const SourceRange& range) { return Pass::create(range); }));
255   py::class_<Break, Stmt>(m, "Break")
256       .def(py::init(
257           [](const SourceRange& range) { return Break::create(range); }));
258   py::class_<Continue, Stmt>(m, "Continue")
259       .def(py::init(
260           [](const SourceRange& range) { return Continue::create(range); }));
261   py::class_<Dots, Expr>(m, "Dots").def(
262       py::init([](const SourceRange& range) { return Dots::create(range); }));
263   py::class_<If, Stmt>(m, "If").def(
264       py::init([](const SourceRange& range,
265                   const Expr& cond,
266                   std::vector<Stmt> true_branch,
267                   std::vector<Stmt> false_branch) {
268         return If::create(
269             range,
270             cond,
271             wrap_list(range, std::move(true_branch)),
272             wrap_list(range, std::move(false_branch)));
273       }));
274   py::class_<While, Stmt>(m, "While")
275       .def(py::init([](const SourceRange& range,
276                        const Expr& cond,
277                        std::vector<Stmt> body) {
278         return While::create(range, cond, wrap_list(range, std::move(body)));
279       }));
280   py::class_<With, Stmt>(m, "With").def(
281       py::init([](const SourceRange& range,
282                   std::vector<WithItem> targets,
283                   std::vector<Stmt> body) {
284         return With::create(
285             range,
286             wrap_list(range, std::move(targets)),
287             wrap_list(range, std::move(body)));
288       }));
289   py::class_<For, Stmt>(m, "For").def(py::init([](const SourceRange& range,
290                                                   std::vector<Expr>& targets,
291                                                   std::vector<Expr>& itrs,
292                                                   std::vector<Stmt> body) {
293     return For::create(
294         range,
295         wrap_list(range, std::move(targets)),
296         wrap_list(range, std::move(itrs)),
297         wrap_list(range, std::move(body)));
298   }));
299   py::class_<ExprStmt, Stmt>(m, "ExprStmt").def(py::init([](const Expr& expr) {
300     return ExprStmt::create(expr.range(), expr);
301   }));
302 
303   py::class_<Var, Expr>(m, "Var")
304       .def(py::init(
305           [](const Ident& name) { return Var::create(name.range(), name); }))
306       .def_property_readonly("name", [](const Var& var) { return var.name(); });
307   py::class_<BinOp, Expr>(m, "BinOp")
308       .def(py::init(
309           [](const std::string& kind, const Expr& lhs, const Expr& rhs) {
310             return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
311           }));
312   // NB: we take range here, because unary ops precede their exprs, so we need
313   // to include them
314   py::class_<UnaryOp, Expr>(m, "UnaryOp")
315       .def(py::init([](const SourceRange& range,
316                        const std::string& kind,
317                        const Expr& expr) {
318         auto resolved_kind = stringToKind(kind);
319         resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind;
320         return UnaryOp::create(range, resolved_kind, expr);
321       }));
322   py::class_<Const, Expr>(m, "Const")
323       .def(py::init([](const SourceRange& range, const std::string& value) {
324         return Const::create(range, value);
325       }));
326   py::class_<StringLiteral, Expr>(m, "StringLiteral")
327       .def(py::init([](const SourceRange& range, const std::string& value) {
328         return StringLiteral::create(range, value);
329       }));
330   py::class_<Apply, Expr>(m, "Apply")
331       .def(py::init([](const Expr& expr,
332                        std::vector<Expr> args,
333                        std::vector<Attribute> kwargs) {
334         const auto& r = expr.range();
335         return Apply::create(
336             expr.range(),
337             expr,
338             wrap_list(r, std::move(args)),
339             wrap_list(r, std::move(kwargs)));
340       }));
341   py::class_<Select, Expr>(m, "Select")
342       .def(py::init([](const Expr& expr, const Ident& field) {
343         return Select::create(expr.range(), expr, field);
344       }));
345   py::class_<TernaryIf, Expr>(m, "TernaryIf")
346       .def(py::init(
347           [](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
348             return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
349           }));
350   py::class_<ListComp, Expr>(m, "ListComp")
351       .def(py::init([](const SourceRange& range,
352                        const Expr& elt,
353                        const Expr& target,
354                        const Expr& iter) {
355         return ListComp::create(range, elt, target, iter);
356       }));
357   py::class_<DictComp, Expr>(m, "DictComp")
358       .def(py::init([](const SourceRange& range,
359                        const Expr& key,
360                        const Expr& value,
361                        const Expr& target,
362                        const Expr& iter) {
363         return DictComp::create(range, key, value, target, iter);
364       }));
365   py::class_<ListLiteral, Expr>(m, "ListLiteral")
366       .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
367         return ListLiteral::create(range, wrap_list(range, std::move(args)));
368       }));
369   py::class_<TupleLiteral, Expr>(m, "TupleLiteral")
370       .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
371         return TupleLiteral::create(range, wrap_list(range, std::move(args)));
372       }));
373   py::class_<DictLiteral, Expr>(m, "DictLiteral")
374       .def(py::init([](const SourceRange& range,
375                        std::vector<Expr> keys,
376                        std::vector<Expr> values) {
377         return DictLiteral::create(
378             range,
379             wrap_list(range, std::move(keys)),
380             wrap_list(range, std::move(values)));
381       }));
382   py::class_<Subscript, Expr>(m, "Subscript")
383       .def(py::init([](const Expr& base, std::vector<Expr> subscript_exprs) {
384         return Subscript::create(
385             base.range(),
386             base,
387             wrap_list(base.range(), std::move(subscript_exprs)));
388       }));
389   py::class_<SliceExpr, Expr>(m, "SliceExpr")
390       .def(py::init(
391           [](const SourceRange& range, Expr* lower, Expr* upper, Expr* step) {
392             return SliceExpr::create(
393                 range,
394                 wrap_maybe(range, lower),
395                 wrap_maybe(range, upper),
396                 wrap_maybe(range, step));
397           }));
398   py::class_<Starred, Expr>(m, "Starred")
399       .def(py::init([](const SourceRange& range, const Expr& expr) {
400         return Starred::create(range, expr);
401       }));
402   py::class_<Maybe<Expr>, TreeView>(m, "EmptyTypeAnnotation")
403       .def(py::init(
404           [](const SourceRange& range) { return Maybe<Expr>::create(range); }));
405 }
406 
407 } // namespace torch::jit
408