xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/python/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/python.h>
2 #include <torch/python/init.h>
3 
4 #include <torch/nn/module.h>
5 #include <torch/ordered_dict.h>
6 
7 #include <torch/csrc/utils/pybind.h>
8 
9 #include <string>
10 #include <vector>
11 
12 namespace py = pybind11;
13 
14 namespace pybind11 {
15 namespace detail {
16 #define ITEM_TYPE_CASTER(T, Name)                                             \
17   template <>                                                                 \
18   struct type_caster<typename torch::OrderedDict<std::string, T>::Item> {     \
19    public:                                                                    \
20     using Item = typename torch::OrderedDict<std::string, T>::Item;           \
21     using PairCaster = make_caster<std::pair<std::string, T>>;                \
22     PYBIND11_TYPE_CASTER(Item, _("Ordered" #Name "DictItem"));                \
23     bool load(handle src, bool convert) {                                     \
24       return PairCaster().load(src, convert);                                 \
25     }                                                                         \
26     static handle cast(Item src, return_value_policy policy, handle parent) { \
27       return PairCaster::cast(                                                \
28           src.pair(), std::move(policy), std::move(parent));                  \
29     }                                                                         \
30   }
31 
32 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
33 ITEM_TYPE_CASTER(torch::Tensor, Tensor);
34 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
35 ITEM_TYPE_CASTER(std::shared_ptr<torch::nn::Module>, Module);
36 } // namespace detail
37 } // namespace pybind11
38 
39 namespace torch {
40 namespace python {
41 namespace {
42 template <typename T>
bind_ordered_dict(py::module module,const char * dict_name)43 void bind_ordered_dict(py::module module, const char* dict_name) {
44   using ODict = OrderedDict<std::string, T>;
45   // clang-format off
46   py::class_<ODict>(module, dict_name)
47       .def("items", &ODict::items)
48       .def("keys", &ODict::keys)
49       .def("values", &ODict::values)
50       .def("__iter__", [](const ODict& dict) {
51             return py::make_iterator(dict.begin(), dict.end());
52           }, py::keep_alive<0, 1>())
53       .def("__len__", &ODict::size)
54       .def("__contains__", &ODict::contains)
55       .def("__getitem__", [](const ODict& dict, const std::string& key) {
56         return dict[key];
57       })
58       .def("__getitem__", [](const ODict& dict, size_t index) {
59         return dict[index];
60       });
61   // clang-format on
62 }
63 } // namespace
64 
init_bindings(PyObject * module)65 void init_bindings(PyObject* module) {
66   py::module m = py::handle(module).cast<py::module>();
67   py::module cpp = m.def_submodule("cpp");
68 
69   bind_ordered_dict<Tensor>(cpp, "OrderedTensorDict");
70   bind_ordered_dict<std::shared_ptr<nn::Module>>(cpp, "OrderedModuleDict");
71 
72   py::module nn = cpp.def_submodule("nn");
73   add_module_bindings(
74       py::class_<nn::Module, std::shared_ptr<nn::Module>>(nn, "Module"));
75 }
76 } // namespace python
77 } // namespace torch
78