xref: /aosp_15_r20/external/pytorch/torch/custom_class.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
#include <ATen/core/builtin_function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/class_type.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/stack.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/TypeTraits.h>
#include <torch/custom_class_detail.h>
#include <torch/library.h>
#include <sstream>
#include <type_traits>
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker namespace torch {
18*da0073e9SAndroid Build Coastguard Worker 
19*da0073e9SAndroid Build Coastguard Worker /// This function is used in conjunction with `class_::def()` to register
20*da0073e9SAndroid Build Coastguard Worker /// a constructor for a given C++ class type. For example,
21*da0073e9SAndroid Build Coastguard Worker /// `torch::init<int, std::string>()` would register a two-argument constructor
22*da0073e9SAndroid Build Coastguard Worker /// taking an `int` and a `std::string` as argument.
23*da0073e9SAndroid Build Coastguard Worker template <class... Types>
init()24*da0073e9SAndroid Build Coastguard Worker detail::types<void, Types...> init() {
25*da0073e9SAndroid Build Coastguard Worker   return detail::types<void, Types...>{};
26*da0073e9SAndroid Build Coastguard Worker }
27*da0073e9SAndroid Build Coastguard Worker 
/// Carrier returned by `torch::init(callable)`. It holds the user's factory
/// callable in `f`. `ParameterTypeList` carries the callable's parameter types
/// (as a single typelist), so that `class_::def()` can register it as
/// `__init__`.
template <class Func, class... ParameterTypeList>
struct InitLambda {
  /// The user-supplied factory callable.
  Func f;
};
32*da0073e9SAndroid Build Coastguard Worker 
33*da0073e9SAndroid Build Coastguard Worker template <typename Func>
decltype(auto)34*da0073e9SAndroid Build Coastguard Worker decltype(auto) init(Func&& f) {
35*da0073e9SAndroid Build Coastguard Worker   using InitTraits = c10::guts::infer_function_traits_t<std::decay_t<Func>>;
36*da0073e9SAndroid Build Coastguard Worker   using ParameterTypeList = typename InitTraits::parameter_types;
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker   InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
39*da0073e9SAndroid Build Coastguard Worker   return init;
40*da0073e9SAndroid Build Coastguard Worker }
41*da0073e9SAndroid Build Coastguard Worker 
42*da0073e9SAndroid Build Coastguard Worker /// Entry point for custom C++ class registration. To register a C++ class
43*da0073e9SAndroid Build Coastguard Worker /// in PyTorch, instantiate `torch::class_` with the desired class as the
44*da0073e9SAndroid Build Coastguard Worker /// template parameter. Typically, this instantiation should be done in
45*da0073e9SAndroid Build Coastguard Worker /// the initialization of a global variable, so that the class will be
46*da0073e9SAndroid Build Coastguard Worker /// made available on dynamic library loading without any additional API
47*da0073e9SAndroid Build Coastguard Worker /// calls needed. For example, to register a class named Foo, you might
48*da0073e9SAndroid Build Coastguard Worker /// create a global variable like so:
49*da0073e9SAndroid Build Coastguard Worker ///
50*da0073e9SAndroid Build Coastguard Worker ///     static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
51*da0073e9SAndroid Build Coastguard Worker ///       .def("myMethod", &Foo::myMethod)
52*da0073e9SAndroid Build Coastguard Worker ///       .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
53*da0073e9SAndroid Build Coastguard Worker ///         // Do something with `self`
54*da0073e9SAndroid Build Coastguard Worker ///       });
55*da0073e9SAndroid Build Coastguard Worker ///
56*da0073e9SAndroid Build Coastguard Worker /// In addition to registering the class, this registration also chains
57*da0073e9SAndroid Build Coastguard Worker /// `def()` calls to register methods. `myMethod()` is registered with
58*da0073e9SAndroid Build Coastguard Worker /// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()`
59*da0073e9SAndroid Build Coastguard Worker /// is registered with a C++ lambda expression.
60*da0073e9SAndroid Build Coastguard Worker template <class CurClass>
61*da0073e9SAndroid Build Coastguard Worker class class_ : public ::torch::detail::class_base {
62*da0073e9SAndroid Build Coastguard Worker   static_assert(
63*da0073e9SAndroid Build Coastguard Worker       std::is_base_of_v<CustomClassHolder, CurClass>,
64*da0073e9SAndroid Build Coastguard Worker       "torch::class_<T> requires T to inherit from CustomClassHolder");
65*da0073e9SAndroid Build Coastguard Worker 
66*da0073e9SAndroid Build Coastguard Worker  public:
67*da0073e9SAndroid Build Coastguard Worker   /// This constructor actually registers the class type.
68*da0073e9SAndroid Build Coastguard Worker   /// String argument `namespaceName` is an identifier for the
69*da0073e9SAndroid Build Coastguard Worker   /// namespace you would like this class to appear in.
70*da0073e9SAndroid Build Coastguard Worker   /// String argument `className` is the name you would like to
71*da0073e9SAndroid Build Coastguard Worker   /// see this class exposed as in Python and TorchScript. For example, if
72*da0073e9SAndroid Build Coastguard Worker   /// you pass `foo` as the namespace name and `Bar` as the className, the
73*da0073e9SAndroid Build Coastguard Worker   /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
74*da0073e9SAndroid Build Coastguard Worker   explicit class_(
75*da0073e9SAndroid Build Coastguard Worker       const std::string& namespaceName,
76*da0073e9SAndroid Build Coastguard Worker       const std::string& className,
77*da0073e9SAndroid Build Coastguard Worker       std::string doc_string = "")
class_base(namespaceName,className,std::move (doc_string),typeid (c10::intrusive_ptr<CurClass>),typeid (c10::tagged_capsule<CurClass>))78*da0073e9SAndroid Build Coastguard Worker       : class_base(
79*da0073e9SAndroid Build Coastguard Worker             namespaceName,
80*da0073e9SAndroid Build Coastguard Worker             className,
81*da0073e9SAndroid Build Coastguard Worker             std::move(doc_string),
82*da0073e9SAndroid Build Coastguard Worker             typeid(c10::intrusive_ptr<CurClass>),
83*da0073e9SAndroid Build Coastguard Worker             typeid(c10::tagged_capsule<CurClass>)) {}
84*da0073e9SAndroid Build Coastguard Worker 
85*da0073e9SAndroid Build Coastguard Worker   /// def() can be used in conjunction with `torch::init()` to register
86*da0073e9SAndroid Build Coastguard Worker   /// a constructor for a given C++ class type. For example, passing
87*da0073e9SAndroid Build Coastguard Worker   /// `torch::init<int, std::string>()` would register a two-argument
88*da0073e9SAndroid Build Coastguard Worker   /// constructor taking an `int` and a `std::string` as argument.
89*da0073e9SAndroid Build Coastguard Worker   template <typename... Types>
90*da0073e9SAndroid Build Coastguard Worker   class_& def(
91*da0073e9SAndroid Build Coastguard Worker       torch::detail::types<void, Types...>,
92*da0073e9SAndroid Build Coastguard Worker       std::string doc_string = "",
93*da0073e9SAndroid Build Coastguard Worker       std::initializer_list<arg> default_args =
94*da0073e9SAndroid Build Coastguard Worker           {}) { // Used in combination with
95*da0073e9SAndroid Build Coastguard Worker     // torch::init<...>()
96*da0073e9SAndroid Build Coastguard Worker     auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
97*da0073e9SAndroid Build Coastguard Worker       auto classObj = c10::make_intrusive<CurClass>(args...);
98*da0073e9SAndroid Build Coastguard Worker       auto object = self.ivalue.toObject();
99*da0073e9SAndroid Build Coastguard Worker       object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
100*da0073e9SAndroid Build Coastguard Worker     };
101*da0073e9SAndroid Build Coastguard Worker 
102*da0073e9SAndroid Build Coastguard Worker     defineMethod(
103*da0073e9SAndroid Build Coastguard Worker         "__init__",
104*da0073e9SAndroid Build Coastguard Worker         std::move(func),
105*da0073e9SAndroid Build Coastguard Worker         std::move(doc_string),
106*da0073e9SAndroid Build Coastguard Worker         default_args);
107*da0073e9SAndroid Build Coastguard Worker     return *this;
108*da0073e9SAndroid Build Coastguard Worker   }
109*da0073e9SAndroid Build Coastguard Worker 
110*da0073e9SAndroid Build Coastguard Worker   // Used in combination with torch::init([]lambda(){......})
111*da0073e9SAndroid Build Coastguard Worker   template <typename Func, typename... ParameterTypes>
112*da0073e9SAndroid Build Coastguard Worker   class_& def(
113*da0073e9SAndroid Build Coastguard Worker       InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
114*da0073e9SAndroid Build Coastguard Worker       std::string doc_string = "",
115*da0073e9SAndroid Build Coastguard Worker       std::initializer_list<arg> default_args = {}) {
116*da0073e9SAndroid Build Coastguard Worker     auto init_lambda_wrapper = [func = std::move(init.f)](
117*da0073e9SAndroid Build Coastguard Worker                                    c10::tagged_capsule<CurClass> self,
118*da0073e9SAndroid Build Coastguard Worker                                    ParameterTypes... arg) {
119*da0073e9SAndroid Build Coastguard Worker       c10::intrusive_ptr<CurClass> classObj =
120*da0073e9SAndroid Build Coastguard Worker           at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
121*da0073e9SAndroid Build Coastguard Worker       auto object = self.ivalue.toObject();
122*da0073e9SAndroid Build Coastguard Worker       object->setSlot(0, c10::IValue::make_capsule(classObj));
123*da0073e9SAndroid Build Coastguard Worker     };
124*da0073e9SAndroid Build Coastguard Worker 
125*da0073e9SAndroid Build Coastguard Worker     defineMethod(
126*da0073e9SAndroid Build Coastguard Worker         "__init__",
127*da0073e9SAndroid Build Coastguard Worker         std::move(init_lambda_wrapper),
128*da0073e9SAndroid Build Coastguard Worker         std::move(doc_string),
129*da0073e9SAndroid Build Coastguard Worker         default_args);
130*da0073e9SAndroid Build Coastguard Worker 
131*da0073e9SAndroid Build Coastguard Worker     return *this;
132*da0073e9SAndroid Build Coastguard Worker   }
133*da0073e9SAndroid Build Coastguard Worker 
134*da0073e9SAndroid Build Coastguard Worker   /// This is the normal method registration API. `name` is the name that
135*da0073e9SAndroid Build Coastguard Worker   /// the method will be made accessible by in Python and TorchScript.
136*da0073e9SAndroid Build Coastguard Worker   /// `f` is a callable object that defines the method. Typically `f`
137*da0073e9SAndroid Build Coastguard Worker   /// will either be a pointer to a method on `CurClass`, or a lambda
138*da0073e9SAndroid Build Coastguard Worker   /// expression that takes a `c10::intrusive_ptr<CurClass>` as the first
139*da0073e9SAndroid Build Coastguard Worker   /// argument (emulating a `this` argument in a C++ method.)
140*da0073e9SAndroid Build Coastguard Worker   ///
141*da0073e9SAndroid Build Coastguard Worker   /// Examples:
142*da0073e9SAndroid Build Coastguard Worker   ///
143*da0073e9SAndroid Build Coastguard Worker   ///     // Exposes method `foo` on C++ class `Foo` as `call_foo()` in
144*da0073e9SAndroid Build Coastguard Worker   ///     // Python and TorchScript
145*da0073e9SAndroid Build Coastguard Worker   ///     .def("call_foo", &Foo::foo)
146*da0073e9SAndroid Build Coastguard Worker   ///
147*da0073e9SAndroid Build Coastguard Worker   ///     // Exposes the given lambda expression as method `call_lambda()`
148*da0073e9SAndroid Build Coastguard Worker   ///     // in Python and TorchScript.
149*da0073e9SAndroid Build Coastguard Worker   ///     .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) {
150*da0073e9SAndroid Build Coastguard Worker   ///       // do something
151*da0073e9SAndroid Build Coastguard Worker   ///     })
152*da0073e9SAndroid Build Coastguard Worker   template <typename Func>
153*da0073e9SAndroid Build Coastguard Worker   class_& def(
154*da0073e9SAndroid Build Coastguard Worker       std::string name,
155*da0073e9SAndroid Build Coastguard Worker       Func f,
156*da0073e9SAndroid Build Coastguard Worker       std::string doc_string = "",
157*da0073e9SAndroid Build Coastguard Worker       std::initializer_list<arg> default_args = {}) {
158*da0073e9SAndroid Build Coastguard Worker     auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
159*da0073e9SAndroid Build Coastguard Worker     defineMethod(
160*da0073e9SAndroid Build Coastguard Worker         std::move(name),
161*da0073e9SAndroid Build Coastguard Worker         std::move(wrapped_f),
162*da0073e9SAndroid Build Coastguard Worker         std::move(doc_string),
163*da0073e9SAndroid Build Coastguard Worker         default_args);
164*da0073e9SAndroid Build Coastguard Worker     return *this;
165*da0073e9SAndroid Build Coastguard Worker   }
166*da0073e9SAndroid Build Coastguard Worker 
167*da0073e9SAndroid Build Coastguard Worker   /// Method registration API for static methods.
168*da0073e9SAndroid Build Coastguard Worker   template <typename Func>
169*da0073e9SAndroid Build Coastguard Worker   class_& def_static(std::string name, Func func, std::string doc_string = "") {
170*da0073e9SAndroid Build Coastguard Worker     auto qualMethodName = qualClassName + "." + name;
171*da0073e9SAndroid Build Coastguard Worker     auto schema =
172*da0073e9SAndroid Build Coastguard Worker         c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
173*da0073e9SAndroid Build Coastguard Worker 
174*da0073e9SAndroid Build Coastguard Worker     auto wrapped_func =
175*da0073e9SAndroid Build Coastguard Worker         [func = std::move(func)](jit::Stack& stack) mutable -> void {
176*da0073e9SAndroid Build Coastguard Worker       using RetType =
177*da0073e9SAndroid Build Coastguard Worker           typename c10::guts::infer_function_traits_t<Func>::return_type;
178*da0073e9SAndroid Build Coastguard Worker       detail::BoxedProxy<RetType, Func>()(stack, func);
179*da0073e9SAndroid Build Coastguard Worker     };
180*da0073e9SAndroid Build Coastguard Worker     auto method = std::make_unique<jit::BuiltinOpFunction>(
181*da0073e9SAndroid Build Coastguard Worker         std::move(qualMethodName),
182*da0073e9SAndroid Build Coastguard Worker         std::move(schema),
183*da0073e9SAndroid Build Coastguard Worker         std::move(wrapped_func),
184*da0073e9SAndroid Build Coastguard Worker         std::move(doc_string));
185*da0073e9SAndroid Build Coastguard Worker 
186*da0073e9SAndroid Build Coastguard Worker     classTypePtr->addStaticMethod(method.get());
187*da0073e9SAndroid Build Coastguard Worker     registerCustomClassMethod(std::move(method));
188*da0073e9SAndroid Build Coastguard Worker     return *this;
189*da0073e9SAndroid Build Coastguard Worker   }
190*da0073e9SAndroid Build Coastguard Worker 
191*da0073e9SAndroid Build Coastguard Worker   /// Property registration API for properties with both getter and setter
192*da0073e9SAndroid Build Coastguard Worker   /// functions.
193*da0073e9SAndroid Build Coastguard Worker   template <typename GetterFunc, typename SetterFunc>
194*da0073e9SAndroid Build Coastguard Worker   class_& def_property(
195*da0073e9SAndroid Build Coastguard Worker       const std::string& name,
196*da0073e9SAndroid Build Coastguard Worker       GetterFunc getter_func,
197*da0073e9SAndroid Build Coastguard Worker       SetterFunc setter_func,
198*da0073e9SAndroid Build Coastguard Worker       std::string doc_string = "") {
199*da0073e9SAndroid Build Coastguard Worker     torch::jit::Function* getter{};
200*da0073e9SAndroid Build Coastguard Worker     torch::jit::Function* setter{};
201*da0073e9SAndroid Build Coastguard Worker 
202*da0073e9SAndroid Build Coastguard Worker     auto wrapped_getter =
203*da0073e9SAndroid Build Coastguard Worker         detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
204*da0073e9SAndroid Build Coastguard Worker     getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
205*da0073e9SAndroid Build Coastguard Worker 
206*da0073e9SAndroid Build Coastguard Worker     auto wrapped_setter =
207*da0073e9SAndroid Build Coastguard Worker         detail::wrap_func<CurClass, SetterFunc>(std::move(setter_func));
208*da0073e9SAndroid Build Coastguard Worker     setter = defineMethod(name + "_setter", wrapped_setter, doc_string);
209*da0073e9SAndroid Build Coastguard Worker 
210*da0073e9SAndroid Build Coastguard Worker     classTypePtr->addProperty(name, getter, setter);
211*da0073e9SAndroid Build Coastguard Worker     return *this;
212*da0073e9SAndroid Build Coastguard Worker   }
213*da0073e9SAndroid Build Coastguard Worker 
214*da0073e9SAndroid Build Coastguard Worker   /// Property registration API for properties with only getter function.
215*da0073e9SAndroid Build Coastguard Worker   template <typename GetterFunc>
216*da0073e9SAndroid Build Coastguard Worker   class_& def_property(
217*da0073e9SAndroid Build Coastguard Worker       const std::string& name,
218*da0073e9SAndroid Build Coastguard Worker       GetterFunc getter_func,
219*da0073e9SAndroid Build Coastguard Worker       std::string doc_string = "") {
220*da0073e9SAndroid Build Coastguard Worker     torch::jit::Function* getter{};
221*da0073e9SAndroid Build Coastguard Worker 
222*da0073e9SAndroid Build Coastguard Worker     auto wrapped_getter =
223*da0073e9SAndroid Build Coastguard Worker         detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
224*da0073e9SAndroid Build Coastguard Worker     getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
225*da0073e9SAndroid Build Coastguard Worker 
226*da0073e9SAndroid Build Coastguard Worker     classTypePtr->addProperty(name, getter, nullptr);
227*da0073e9SAndroid Build Coastguard Worker     return *this;
228*da0073e9SAndroid Build Coastguard Worker   }
229*da0073e9SAndroid Build Coastguard Worker 
230*da0073e9SAndroid Build Coastguard Worker   /// Property registration API for properties with read-write access.
231*da0073e9SAndroid Build Coastguard Worker   template <typename T>
def_readwrite(const std::string & name,T CurClass::* field)232*da0073e9SAndroid Build Coastguard Worker   class_& def_readwrite(const std::string& name, T CurClass::*field) {
233*da0073e9SAndroid Build Coastguard Worker     auto getter_func = [field =
234*da0073e9SAndroid Build Coastguard Worker                             field](const c10::intrusive_ptr<CurClass>& self) {
235*da0073e9SAndroid Build Coastguard Worker       return self.get()->*field;
236*da0073e9SAndroid Build Coastguard Worker     };
237*da0073e9SAndroid Build Coastguard Worker 
238*da0073e9SAndroid Build Coastguard Worker     auto setter_func = [field = field](
239*da0073e9SAndroid Build Coastguard Worker                            const c10::intrusive_ptr<CurClass>& self, T value) {
240*da0073e9SAndroid Build Coastguard Worker       self.get()->*field = value;
241*da0073e9SAndroid Build Coastguard Worker     };
242*da0073e9SAndroid Build Coastguard Worker 
243*da0073e9SAndroid Build Coastguard Worker     return def_property(name, getter_func, setter_func);
244*da0073e9SAndroid Build Coastguard Worker   }
245*da0073e9SAndroid Build Coastguard Worker 
246*da0073e9SAndroid Build Coastguard Worker   /// Property registration API for properties with read-only access.
247*da0073e9SAndroid Build Coastguard Worker   template <typename T>
def_readonly(const std::string & name,T CurClass::* field)248*da0073e9SAndroid Build Coastguard Worker   class_& def_readonly(const std::string& name, T CurClass::*field) {
249*da0073e9SAndroid Build Coastguard Worker     auto getter_func =
250*da0073e9SAndroid Build Coastguard Worker         [field = std::move(field)](const c10::intrusive_ptr<CurClass>& self) {
251*da0073e9SAndroid Build Coastguard Worker           return self.get()->*field;
252*da0073e9SAndroid Build Coastguard Worker         };
253*da0073e9SAndroid Build Coastguard Worker 
254*da0073e9SAndroid Build Coastguard Worker     return def_property(name, getter_func);
255*da0073e9SAndroid Build Coastguard Worker   }
256*da0073e9SAndroid Build Coastguard Worker 
257*da0073e9SAndroid Build Coastguard Worker   /// This is an unsafe method registration API added for adding custom JIT
258*da0073e9SAndroid Build Coastguard Worker   /// backend support via custom C++ classes. It is not for general purpose use.
259*da0073e9SAndroid Build Coastguard Worker   class_& _def_unboxed(
260*da0073e9SAndroid Build Coastguard Worker       const std::string& name,
261*da0073e9SAndroid Build Coastguard Worker       std::function<void(jit::Stack&)> func,
262*da0073e9SAndroid Build Coastguard Worker       c10::FunctionSchema schema,
263*da0073e9SAndroid Build Coastguard Worker       std::string doc_string = "") {
264*da0073e9SAndroid Build Coastguard Worker     auto method = std::make_unique<jit::BuiltinOpFunction>(
265*da0073e9SAndroid Build Coastguard Worker         qualClassName + "." + name,
266*da0073e9SAndroid Build Coastguard Worker         std::move(schema),
267*da0073e9SAndroid Build Coastguard Worker         std::move(func),
268*da0073e9SAndroid Build Coastguard Worker         std::move(doc_string));
269*da0073e9SAndroid Build Coastguard Worker     classTypePtr->addMethod(method.get());
270*da0073e9SAndroid Build Coastguard Worker     registerCustomClassMethod(std::move(method));
271*da0073e9SAndroid Build Coastguard Worker     return *this;
272*da0073e9SAndroid Build Coastguard Worker   }
273*da0073e9SAndroid Build Coastguard Worker 
274*da0073e9SAndroid Build Coastguard Worker   /// def_pickle() is used to define exactly what state gets serialized
275*da0073e9SAndroid Build Coastguard Worker   /// or deserialized for a given instance of a custom C++ class in
276*da0073e9SAndroid Build Coastguard Worker   /// Python or TorchScript. This protocol is equivalent to the Pickle
277*da0073e9SAndroid Build Coastguard Worker   /// concept of `__getstate__` and `__setstate__` from Python
278*da0073e9SAndroid Build Coastguard Worker   /// (https://docs.python.org/2/library/pickle.html#object.__getstate__)
279*da0073e9SAndroid Build Coastguard Worker   ///
280*da0073e9SAndroid Build Coastguard Worker   /// Currently, both the `get_state` and `set_state` callables must be
281*da0073e9SAndroid Build Coastguard Worker   /// C++ lambda expressions. They should have the following signatures,
282*da0073e9SAndroid Build Coastguard Worker   /// where `CurClass` is the class you're registering and `T1` is some object
283*da0073e9SAndroid Build Coastguard Worker   /// that encapsulates the state of the object.
284*da0073e9SAndroid Build Coastguard Worker   ///
285*da0073e9SAndroid Build Coastguard Worker   ///     __getstate__(intrusive_ptr<CurClass>) -> T1
286*da0073e9SAndroid Build Coastguard Worker   ///     __setstate__(T2) -> intrusive_ptr<CurClass>
287*da0073e9SAndroid Build Coastguard Worker   ///
288*da0073e9SAndroid Build Coastguard Worker   /// `T1` must be an object that is convertable to IValue by the same rules
289*da0073e9SAndroid Build Coastguard Worker   /// for custom op/method registration.
290*da0073e9SAndroid Build Coastguard Worker   ///
291*da0073e9SAndroid Build Coastguard Worker   /// For the common case, T1 == T2. T1 can also be a subtype of T2. An
292*da0073e9SAndroid Build Coastguard Worker   /// example where it makes sense for T1 and T2 to differ is if __setstate__
293*da0073e9SAndroid Build Coastguard Worker   /// handles legacy formats in a backwards compatible way.
294*da0073e9SAndroid Build Coastguard Worker   ///
295*da0073e9SAndroid Build Coastguard Worker   /// Example:
296*da0073e9SAndroid Build Coastguard Worker   ///
297*da0073e9SAndroid Build Coastguard Worker   ///     .def_pickle(
298*da0073e9SAndroid Build Coastguard Worker   ///         // __getstate__
299*da0073e9SAndroid Build Coastguard Worker   ///         [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
300*da0073e9SAndroid Build Coastguard Worker   ///           return self->stack_;
301*da0073e9SAndroid Build Coastguard Worker   ///         },
302*da0073e9SAndroid Build Coastguard Worker   ///         [](std::vector<std::string> state) { // __setstate__
303*da0073e9SAndroid Build Coastguard Worker   ///            return c10::make_intrusive<MyStackClass<std::string>>(
304*da0073e9SAndroid Build Coastguard Worker   ///               std::vector<std::string>{"i", "was", "deserialized"});
305*da0073e9SAndroid Build Coastguard Worker   ///         })
306*da0073e9SAndroid Build Coastguard Worker   template <typename GetStateFn, typename SetStateFn>
307*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
def_pickle(GetStateFn && get_state,SetStateFn && set_state)308*da0073e9SAndroid Build Coastguard Worker   class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
309*da0073e9SAndroid Build Coastguard Worker     static_assert(
310*da0073e9SAndroid Build Coastguard Worker         c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
311*da0073e9SAndroid Build Coastguard Worker             c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
312*da0073e9SAndroid Build Coastguard Worker         "def_pickle() currently only supports lambdas as "
313*da0073e9SAndroid Build Coastguard Worker         "__getstate__ and __setstate__ arguments.");
314*da0073e9SAndroid Build Coastguard Worker     def("__getstate__", std::forward<GetStateFn>(get_state));
315*da0073e9SAndroid Build Coastguard Worker 
316*da0073e9SAndroid Build Coastguard Worker     // __setstate__ needs to be registered with some custom handling:
317*da0073e9SAndroid Build Coastguard Worker     // We need to wrap the invocation of the user-provided function
318*da0073e9SAndroid Build Coastguard Worker     // such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
319*da0073e9SAndroid Build Coastguard Worker     // and assign it to the `capsule` attribute.
320*da0073e9SAndroid Build Coastguard Worker     using SetStateTraits =
321*da0073e9SAndroid Build Coastguard Worker         c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
322*da0073e9SAndroid Build Coastguard Worker     using SetStateArg = typename c10::guts::typelist::head_t<
323*da0073e9SAndroid Build Coastguard Worker         typename SetStateTraits::parameter_types>;
324*da0073e9SAndroid Build Coastguard Worker     auto setstate_wrapper = [set_state = std::forward<SetStateFn>(set_state)](
325*da0073e9SAndroid Build Coastguard Worker                                 c10::tagged_capsule<CurClass> self,
326*da0073e9SAndroid Build Coastguard Worker                                 SetStateArg arg) {
327*da0073e9SAndroid Build Coastguard Worker       c10::intrusive_ptr<CurClass> classObj =
328*da0073e9SAndroid Build Coastguard Worker           at::guts::invoke(set_state, std::move(arg));
329*da0073e9SAndroid Build Coastguard Worker       auto object = self.ivalue.toObject();
330*da0073e9SAndroid Build Coastguard Worker       object->setSlot(0, c10::IValue::make_capsule(classObj));
331*da0073e9SAndroid Build Coastguard Worker     };
332*da0073e9SAndroid Build Coastguard Worker     defineMethod(
333*da0073e9SAndroid Build Coastguard Worker         "__setstate__",
334*da0073e9SAndroid Build Coastguard Worker         detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
335*da0073e9SAndroid Build Coastguard Worker             std::move(setstate_wrapper)));
336*da0073e9SAndroid Build Coastguard Worker 
337*da0073e9SAndroid Build Coastguard Worker     // type validation
338*da0073e9SAndroid Build Coastguard Worker     auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema();
339*da0073e9SAndroid Build Coastguard Worker #ifndef STRIP_ERROR_MESSAGES
340*da0073e9SAndroid Build Coastguard Worker     auto format_getstate_schema = [&getstate_schema]() {
341*da0073e9SAndroid Build Coastguard Worker       std::stringstream ss;
342*da0073e9SAndroid Build Coastguard Worker       ss << getstate_schema;
343*da0073e9SAndroid Build Coastguard Worker       return ss.str();
344*da0073e9SAndroid Build Coastguard Worker     };
345*da0073e9SAndroid Build Coastguard Worker #endif
346*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
347*da0073e9SAndroid Build Coastguard Worker         getstate_schema.arguments().size() == 1,
348*da0073e9SAndroid Build Coastguard Worker         "__getstate__ should take exactly one argument: self. Got: ",
349*da0073e9SAndroid Build Coastguard Worker         format_getstate_schema());
350*da0073e9SAndroid Build Coastguard Worker     auto first_arg_type = getstate_schema.arguments().at(0).type();
351*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
352*da0073e9SAndroid Build Coastguard Worker         *first_arg_type == *classTypePtr,
353*da0073e9SAndroid Build Coastguard Worker         "self argument of __getstate__ must be the custom class type. Got ",
354*da0073e9SAndroid Build Coastguard Worker         first_arg_type->repr_str());
355*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
356*da0073e9SAndroid Build Coastguard Worker         getstate_schema.returns().size() == 1,
357*da0073e9SAndroid Build Coastguard Worker         "__getstate__ should return exactly one value for serialization. Got: ",
358*da0073e9SAndroid Build Coastguard Worker         format_getstate_schema());
359*da0073e9SAndroid Build Coastguard Worker 
360*da0073e9SAndroid Build Coastguard Worker     auto ser_type = getstate_schema.returns().at(0).type();
361*da0073e9SAndroid Build Coastguard Worker     auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema();
362*da0073e9SAndroid Build Coastguard Worker     auto arg_type = setstate_schema.arguments().at(1).type();
363*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
364*da0073e9SAndroid Build Coastguard Worker         ser_type->isSubtypeOf(*arg_type),
365*da0073e9SAndroid Build Coastguard Worker         "__getstate__'s return type should be a subtype of "
366*da0073e9SAndroid Build Coastguard Worker         "input argument of __setstate__. Got ",
367*da0073e9SAndroid Build Coastguard Worker         ser_type->repr_str(),
368*da0073e9SAndroid Build Coastguard Worker         " but expected ",
369*da0073e9SAndroid Build Coastguard Worker         arg_type->repr_str());
370*da0073e9SAndroid Build Coastguard Worker 
371*da0073e9SAndroid Build Coastguard Worker     return *this;
372*da0073e9SAndroid Build Coastguard Worker   }
373*da0073e9SAndroid Build Coastguard Worker 
374*da0073e9SAndroid Build Coastguard Worker  private:
375*da0073e9SAndroid Build Coastguard Worker   template <typename Func>
376*da0073e9SAndroid Build Coastguard Worker   torch::jit::Function* defineMethod(
377*da0073e9SAndroid Build Coastguard Worker       std::string name,
378*da0073e9SAndroid Build Coastguard Worker       Func func,
379*da0073e9SAndroid Build Coastguard Worker       std::string doc_string = "",
380*da0073e9SAndroid Build Coastguard Worker       std::initializer_list<arg> default_args = {}) {
381*da0073e9SAndroid Build Coastguard Worker     auto qualMethodName = qualClassName + "." + name;
382*da0073e9SAndroid Build Coastguard Worker     auto schema =
383*da0073e9SAndroid Build Coastguard Worker         c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
384*da0073e9SAndroid Build Coastguard Worker 
385*da0073e9SAndroid Build Coastguard Worker     // If default values are provided for function arguments, there must be
386*da0073e9SAndroid Build Coastguard Worker     // none (no default values) or default values for all function
387*da0073e9SAndroid Build Coastguard Worker     // arguments, except for self. This is because argument names are not
388*da0073e9SAndroid Build Coastguard Worker     // extracted by inferFunctionSchemaSingleReturn, and so there must be a
389*da0073e9SAndroid Build Coastguard Worker     // torch::arg instance in default_args even for arguments that do not
390*da0073e9SAndroid Build Coastguard Worker     // have an actual default value provided.
391*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
392*da0073e9SAndroid Build Coastguard Worker         default_args.size() == 0 ||
393*da0073e9SAndroid Build Coastguard Worker             default_args.size() == schema.arguments().size() - 1,
394*da0073e9SAndroid Build Coastguard Worker         "Default values must be specified for none or all arguments");
395*da0073e9SAndroid Build Coastguard Worker 
396*da0073e9SAndroid Build Coastguard Worker     // If there are default args, copy the argument names and default values to
397*da0073e9SAndroid Build Coastguard Worker     // the function schema.
398*da0073e9SAndroid Build Coastguard Worker     if (default_args.size() > 0) {
399*da0073e9SAndroid Build Coastguard Worker       schema = withNewArguments(schema, default_args);
400*da0073e9SAndroid Build Coastguard Worker     }
401*da0073e9SAndroid Build Coastguard Worker 
402*da0073e9SAndroid Build Coastguard Worker     auto wrapped_func =
403*da0073e9SAndroid Build Coastguard Worker         [func = std::move(func)](jit::Stack& stack) mutable -> void {
404*da0073e9SAndroid Build Coastguard Worker       // TODO: we need to figure out how to profile calls to custom functions
405*da0073e9SAndroid Build Coastguard Worker       // like this! Currently can't do it because the profiler stuff is in
406*da0073e9SAndroid Build Coastguard Worker       // libtorch and not ATen
407*da0073e9SAndroid Build Coastguard Worker       using RetType =
408*da0073e9SAndroid Build Coastguard Worker           typename c10::guts::infer_function_traits_t<Func>::return_type;
409*da0073e9SAndroid Build Coastguard Worker       detail::BoxedProxy<RetType, Func>()(stack, func);
410*da0073e9SAndroid Build Coastguard Worker     };
411*da0073e9SAndroid Build Coastguard Worker     auto method = std::make_unique<jit::BuiltinOpFunction>(
412*da0073e9SAndroid Build Coastguard Worker         qualMethodName,
413*da0073e9SAndroid Build Coastguard Worker         std::move(schema),
414*da0073e9SAndroid Build Coastguard Worker         std::move(wrapped_func),
415*da0073e9SAndroid Build Coastguard Worker         std::move(doc_string));
416*da0073e9SAndroid Build Coastguard Worker 
417*da0073e9SAndroid Build Coastguard Worker     // Register the method here to keep the Method alive.
418*da0073e9SAndroid Build Coastguard Worker     // ClassTypes do not hold ownership of their methods (normally it
419*da0073e9SAndroid Build Coastguard Worker     // those are held by the CompilationUnit), so we need a proxy for
420*da0073e9SAndroid Build Coastguard Worker     // that behavior here.
421*da0073e9SAndroid Build Coastguard Worker     auto method_val = method.get();
422*da0073e9SAndroid Build Coastguard Worker     classTypePtr->addMethod(method_val);
423*da0073e9SAndroid Build Coastguard Worker     registerCustomClassMethod(std::move(method));
424*da0073e9SAndroid Build Coastguard Worker     return method_val;
425*da0073e9SAndroid Build Coastguard Worker   }
426*da0073e9SAndroid Build Coastguard Worker };
427*da0073e9SAndroid Build Coastguard Worker 
428*da0073e9SAndroid Build Coastguard Worker /// make_custom_class() is a convenient way to create an instance of a
429*da0073e9SAndroid Build Coastguard Worker /// registered custom class and wrap it in an IValue, for example when you want
430*da0073e9SAndroid Build Coastguard Worker /// to pass the object to TorchScript. Its syntax is equivalent to APIs like
431*da0073e9SAndroid Build Coastguard Worker /// `std::make_shared<>` or `c10::make_intrusive<>`.
432*da0073e9SAndroid Build Coastguard Worker ///
433*da0073e9SAndroid Build Coastguard Worker /// For example, if you have a custom C++ class that can be constructed from an
434*da0073e9SAndroid Build Coastguard Worker /// `int` and `std::string`, you might use this API like so:
435*da0073e9SAndroid Build Coastguard Worker ///
436*da0073e9SAndroid Build Coastguard Worker ///     IValue custom_class_iv = torch::make_custom_class<MyClass>(3,
437*da0073e9SAndroid Build Coastguard Worker ///     "foobarbaz");
438*da0073e9SAndroid Build Coastguard Worker template <typename CurClass, typename... CtorArgs>
make_custom_class(CtorArgs &&...args)439*da0073e9SAndroid Build Coastguard Worker c10::IValue make_custom_class(CtorArgs&&... args) {
440*da0073e9SAndroid Build Coastguard Worker   auto userClassInstance =
441*da0073e9SAndroid Build Coastguard Worker       c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
442*da0073e9SAndroid Build Coastguard Worker   return c10::IValue(std::move(userClassInstance));
443*da0073e9SAndroid Build Coastguard Worker }
444*da0073e9SAndroid Build Coastguard Worker 
445*da0073e9SAndroid Build Coastguard Worker // Alternative api for creating a torchbind class over torch::class_ this api is
446*da0073e9SAndroid Build Coastguard Worker // preffered to prevent size regressions on Edge usecases. Must be used in
447*da0073e9SAndroid Build Coastguard Worker // conjunction with TORCH_SELECTIVE_CLASS macro aka
448*da0073e9SAndroid Build Coastguard Worker // selective_class<foo>("foo_namespace", TORCH_SELECTIVE_CLASS("foo"))
449*da0073e9SAndroid Build Coastguard Worker template <class CurClass>
selective_class_(const std::string & namespace_name,detail::SelectiveStr<true> className)450*da0073e9SAndroid Build Coastguard Worker inline class_<CurClass> selective_class_(
451*da0073e9SAndroid Build Coastguard Worker     const std::string& namespace_name,
452*da0073e9SAndroid Build Coastguard Worker     detail::SelectiveStr<true> className) {
453*da0073e9SAndroid Build Coastguard Worker   auto class_name = std::string(className.operator const char*());
454*da0073e9SAndroid Build Coastguard Worker   return torch::class_<CurClass>(namespace_name, class_name);
455*da0073e9SAndroid Build Coastguard Worker }
456*da0073e9SAndroid Build Coastguard Worker 
457*da0073e9SAndroid Build Coastguard Worker template <class CurClass>
selective_class_(const std::string &,detail::SelectiveStr<false>)458*da0073e9SAndroid Build Coastguard Worker inline detail::ClassNotSelected selective_class_(
459*da0073e9SAndroid Build Coastguard Worker     const std::string&,
460*da0073e9SAndroid Build Coastguard Worker     detail::SelectiveStr<false>) {
461*da0073e9SAndroid Build Coastguard Worker   return detail::ClassNotSelected();
462*da0073e9SAndroid Build Coastguard Worker }
463*da0073e9SAndroid Build Coastguard Worker 
464*da0073e9SAndroid Build Coastguard Worker // jit namespace for backward-compatibility
465*da0073e9SAndroid Build Coastguard Worker // We previously defined everything in torch::jit but moved it out to
466*da0073e9SAndroid Build Coastguard Worker // better reflect that these features are not limited only to TorchScript
467*da0073e9SAndroid Build Coastguard Worker namespace jit {
468*da0073e9SAndroid Build Coastguard Worker 
469*da0073e9SAndroid Build Coastguard Worker using ::torch::class_;
470*da0073e9SAndroid Build Coastguard Worker using ::torch::getCustomClass;
471*da0073e9SAndroid Build Coastguard Worker using ::torch::init;
472*da0073e9SAndroid Build Coastguard Worker using ::torch::isCustomClass;
473*da0073e9SAndroid Build Coastguard Worker 
474*da0073e9SAndroid Build Coastguard Worker } // namespace jit
475*da0073e9SAndroid Build Coastguard Worker 
476*da0073e9SAndroid Build Coastguard Worker template <class CurClass>
class_(const std::string & className)477*da0073e9SAndroid Build Coastguard Worker inline class_<CurClass> Library::class_(const std::string& className) {
478*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
479*da0073e9SAndroid Build Coastguard Worker       kind_ == DEF || kind_ == FRAGMENT,
480*da0073e9SAndroid Build Coastguard Worker       "class_(\"",
481*da0073e9SAndroid Build Coastguard Worker       className,
482*da0073e9SAndroid Build Coastguard Worker       "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block.  "
483*da0073e9SAndroid Build Coastguard Worker       "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace.  "
484*da0073e9SAndroid Build Coastguard Worker       "(Error occurred at ",
485*da0073e9SAndroid Build Coastguard Worker       file_,
486*da0073e9SAndroid Build Coastguard Worker       ":",
487*da0073e9SAndroid Build Coastguard Worker       line_,
488*da0073e9SAndroid Build Coastguard Worker       ")");
489*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
490*da0073e9SAndroid Build Coastguard Worker   return torch::class_<CurClass>(*ns_, className);
491*da0073e9SAndroid Build Coastguard Worker }
492*da0073e9SAndroid Build Coastguard Worker 
// Declared here; the definition lives elsewhere. NOTE(review): presumably
// returns the names of all registered custom classes. Confirm against the
// definition. The top-level `const` on this by-value return prevents callers
// from moving out of the result. It is kept so that this declaration still
// matches the existing definition.
auto getAllCustomClassesNames() -> const std::unordered_set<std::string>;
494*da0073e9SAndroid Build Coastguard Worker 
495*da0073e9SAndroid Build Coastguard Worker template <class CurClass>
class_(detail::SelectiveStr<true> className)496*da0073e9SAndroid Build Coastguard Worker inline class_<CurClass> Library::class_(detail::SelectiveStr<true> className) {
497*da0073e9SAndroid Build Coastguard Worker   auto class_name = std::string(className.operator const char*());
498*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
499*da0073e9SAndroid Build Coastguard Worker       kind_ == DEF || kind_ == FRAGMENT,
500*da0073e9SAndroid Build Coastguard Worker       "class_(\"",
501*da0073e9SAndroid Build Coastguard Worker       class_name,
502*da0073e9SAndroid Build Coastguard Worker       "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block.  "
503*da0073e9SAndroid Build Coastguard Worker       "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace.  "
504*da0073e9SAndroid Build Coastguard Worker       "(Error occurred at ",
505*da0073e9SAndroid Build Coastguard Worker       file_,
506*da0073e9SAndroid Build Coastguard Worker       ":",
507*da0073e9SAndroid Build Coastguard Worker       line_,
508*da0073e9SAndroid Build Coastguard Worker       ")");
509*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
510*da0073e9SAndroid Build Coastguard Worker   return torch::class_<CurClass>(*ns_, class_name);
511*da0073e9SAndroid Build Coastguard Worker }
512*da0073e9SAndroid Build Coastguard Worker 
513*da0073e9SAndroid Build Coastguard Worker template <class CurClass>
class_(detail::SelectiveStr<false>)514*da0073e9SAndroid Build Coastguard Worker inline detail::ClassNotSelected Library::class_(detail::SelectiveStr<false>) {
515*da0073e9SAndroid Build Coastguard Worker   return detail::ClassNotSelected();
516*da0073e9SAndroid Build Coastguard Worker }
517*da0073e9SAndroid Build Coastguard Worker 
518*da0073e9SAndroid Build Coastguard Worker } // namespace torch
519