1 #include <torch/python.h>
2 #include <torch/python/init.h>
3
4 #include <torch/nn/module.h>
5 #include <torch/ordered_dict.h>
6
7 #include <torch/csrc/utils/pybind.h>
8
9 #include <string>
10 #include <vector>
11
12 namespace py = pybind11;
13
14 namespace pybind11 {
15 namespace detail {
16 #define ITEM_TYPE_CASTER(T, Name) \
17 template <> \
18 struct type_caster<typename torch::OrderedDict<std::string, T>::Item> { \
19 public: \
20 using Item = typename torch::OrderedDict<std::string, T>::Item; \
21 using PairCaster = make_caster<std::pair<std::string, T>>; \
22 PYBIND11_TYPE_CASTER(Item, _("Ordered" #Name "DictItem")); \
23 bool load(handle src, bool convert) { \
24 return PairCaster().load(src, convert); \
25 } \
26 static handle cast(Item src, return_value_policy policy, handle parent) { \
27 return PairCaster::cast( \
28 src.pair(), std::move(policy), std::move(parent)); \
29 } \
30 }
31
32 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
33 ITEM_TYPE_CASTER(torch::Tensor, Tensor);
34 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
35 ITEM_TYPE_CASTER(std::shared_ptr<torch::nn::Module>, Module);
36 } // namespace detail
37 } // namespace pybind11
38
39 namespace torch {
40 namespace python {
41 namespace {
42 template <typename T>
bind_ordered_dict(py::module module,const char * dict_name)43 void bind_ordered_dict(py::module module, const char* dict_name) {
44 using ODict = OrderedDict<std::string, T>;
45 // clang-format off
46 py::class_<ODict>(module, dict_name)
47 .def("items", &ODict::items)
48 .def("keys", &ODict::keys)
49 .def("values", &ODict::values)
50 .def("__iter__", [](const ODict& dict) {
51 return py::make_iterator(dict.begin(), dict.end());
52 }, py::keep_alive<0, 1>())
53 .def("__len__", &ODict::size)
54 .def("__contains__", &ODict::contains)
55 .def("__getitem__", [](const ODict& dict, const std::string& key) {
56 return dict[key];
57 })
58 .def("__getitem__", [](const ODict& dict, size_t index) {
59 return dict[index];
60 });
61 // clang-format on
62 }
63 } // namespace
64
init_bindings(PyObject * module)65 void init_bindings(PyObject* module) {
66 py::module m = py::handle(module).cast<py::module>();
67 py::module cpp = m.def_submodule("cpp");
68
69 bind_ordered_dict<Tensor>(cpp, "OrderedTensorDict");
70 bind_ordered_dict<std::shared_ptr<nn::Module>>(cpp, "OrderedModuleDict");
71
72 py::module nn = cpp.def_submodule("nn");
73 add_module_bindings(
74 py::class_<nn::Module, std::shared_ptr<nn::Module>>(nn, "Module"));
75 }
76 } // namespace python
77 } // namespace torch
78