xref: /aosp_15_r20/external/pytorch/torch/csrc/PyInterpreter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/PythonFallbackKernel.h>
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/PythonOpRegistrationTrampoline.h>
3*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/PyInterpreter.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/THP.h>
5*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/generated/VariableType.h>
6*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/python_arg_parser.h>
7*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/python_dispatch.h>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker #include <string>
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker using namespace torch;
12*da0073e9SAndroid Build Coastguard Worker using namespace at;
13*da0073e9SAndroid Build Coastguard Worker using namespace c10;
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker namespace torch::detail {
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker namespace {
18*da0073e9SAndroid Build Coastguard Worker 
// NB: CONCRETE_GPU_TRACE is deliberately a macro rather than a function
// template. An earlier template version that took a `constexpr char*` as a
// template argument broke some MSVC releases still used internally at Meta
// (MSVC 14.16.27023, vs2017_15.9).
//
// Expansion contract (intended to be the whole body of a trace_gpu_* override):
//   - installs an at::impl::MaybeSetTLSOnEntryGuard for the enclosing scope;
//   - if the Python interpreter is not initialized, does nothing else;
//   - otherwise takes the GIL, imports `torch.<device>` (HIP is rewritten to
//     CUDA first, because HIP uses the `torch.cuda` module), and calls
//     `<module>._gpu_trace.<func_name>.fire_callbacks(__VA_ARGS__)`;
//   - a std::exception thrown along the way is logged, not propagated.
// `device_type` must be a modifiable lvalue: it is overwritten for HIP.
#define CONCRETE_GPU_TRACE(device_type, func_name, ...)                       \
  at::impl::MaybeSetTLSOnEntryGuard guard;                                    \
  if (Py_IsInitialized()) {                                                   \
    pybind11::gil_scoped_acquire gil;                                         \
    try {                                                                     \
      /* HIP shares the `torch.cuda` Python module, so report it as CUDA. */  \
      if (device_type == at::kHIP) {                                          \
        device_type = at::kCUDA;                                              \
      }                                                                       \
      const std::string trace_module_path =                                   \
          "torch." + DeviceTypeName(device_type, true);                       \
      py::module trace_module =                                               \
          py::module::import(trace_module_path.c_str());                      \
      py::object callback_registry =                                          \
          trace_module.attr("_gpu_trace").attr(func_name);                    \
      callback_registry.attr("fire_callbacks")(__VA_ARGS__);                  \
    } catch (const std::exception& e) {                                       \
      LOG(ERROR) << device_type                                               \
                 << " trace hook execution failed: " << e.what();             \
    }                                                                         \
  }
42*da0073e9SAndroid Build Coastguard Worker 
43*da0073e9SAndroid Build Coastguard Worker struct ConcretePyInterpreterVTable final
44*da0073e9SAndroid Build Coastguard Worker     : public c10::impl::PyInterpreterVTable {
45*da0073e9SAndroid Build Coastguard Worker   std::string name() const override;
46*da0073e9SAndroid Build Coastguard Worker 
47*da0073e9SAndroid Build Coastguard Worker   void incref(PyObject* pyobj) const override;
48*da0073e9SAndroid Build Coastguard Worker   void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
49*da0073e9SAndroid Build Coastguard Worker 
50*da0073e9SAndroid Build Coastguard Worker   // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
51*da0073e9SAndroid Build Coastguard Worker   // operate upon a PyObjectSlot rather than a TensorImpl
52*da0073e9SAndroid Build Coastguard Worker   c10::intrusive_ptr<c10::TensorImpl> detach(
53*da0073e9SAndroid Build Coastguard Worker       const c10::TensorImpl* self) const override;
54*da0073e9SAndroid Build Coastguard Worker 
55*da0073e9SAndroid Build Coastguard Worker   void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
56*da0073e9SAndroid Build Coastguard Worker       const override;
57*da0073e9SAndroid Build Coastguard Worker   void reportErrorCallback(PyObject* callback, DispatchKey key) const override;
58*da0073e9SAndroid Build Coastguard Worker   void python_dispatcher(
59*da0073e9SAndroid Build Coastguard Worker       const c10::OperatorHandle& op,
60*da0073e9SAndroid Build Coastguard Worker       c10::DispatchKeySet,
61*da0073e9SAndroid Build Coastguard Worker       torch::jit::Stack* stack) const override;
62*da0073e9SAndroid Build Coastguard Worker   // NB: this is defined in python_dispatch.cpp
python_op_registration_trampolinetorch::detail::__anon56d922760111::ConcretePyInterpreterVTable63*da0073e9SAndroid Build Coastguard Worker   void python_op_registration_trampoline(
64*da0073e9SAndroid Build Coastguard Worker       const c10::OperatorHandle& op,
65*da0073e9SAndroid Build Coastguard Worker       c10::DispatchKey key,
66*da0073e9SAndroid Build Coastguard Worker       c10::DispatchKeySet keyset,
67*da0073e9SAndroid Build Coastguard Worker       torch::jit::Stack* stack,
68*da0073e9SAndroid Build Coastguard Worker       bool with_keyset,
69*da0073e9SAndroid Build Coastguard Worker       bool with_op) const override {
70*da0073e9SAndroid Build Coastguard Worker     torch::impl::dispatch::python_op_registration_trampoline_impl(
71*da0073e9SAndroid Build Coastguard Worker         op, key, keyset, stack, with_keyset, with_op);
72*da0073e9SAndroid Build Coastguard Worker   }
throw_abstract_impl_not_imported_errortorch::detail::__anon56d922760111::ConcretePyInterpreterVTable73*da0073e9SAndroid Build Coastguard Worker   void throw_abstract_impl_not_imported_error(
74*da0073e9SAndroid Build Coastguard Worker       std::string opname,
75*da0073e9SAndroid Build Coastguard Worker       const char* pymodule,
76*da0073e9SAndroid Build Coastguard Worker       const char* context) const override {
77*da0073e9SAndroid Build Coastguard Worker     py::gil_scoped_acquire gil;
78*da0073e9SAndroid Build Coastguard Worker     pybind11::module::import("torch._utils_internal")
79*da0073e9SAndroid Build Coastguard Worker         .attr("throw_abstract_impl_not_imported_error")(
80*da0073e9SAndroid Build Coastguard Worker             opname, pymodule, context);
81*da0073e9SAndroid Build Coastguard Worker   }
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker   bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
84*da0073e9SAndroid Build Coastguard Worker       const override;
85*da0073e9SAndroid Build Coastguard Worker   bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
86*da0073e9SAndroid Build Coastguard Worker       const override;
87*da0073e9SAndroid Build Coastguard Worker   bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
88*da0073e9SAndroid Build Coastguard Worker   c10::Device device(const c10::TensorImpl* self) const override;
89*da0073e9SAndroid Build Coastguard Worker   int64_t dim(const c10::TensorImpl* self) const override;
90*da0073e9SAndroid Build Coastguard Worker   c10::IntArrayRef strides(const c10::TensorImpl* self) const override;
91*da0073e9SAndroid Build Coastguard Worker   c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
92*da0073e9SAndroid Build Coastguard Worker   c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
93*da0073e9SAndroid Build Coastguard Worker   c10::Layout layout(const c10::TensorImpl* self) const override;
94*da0073e9SAndroid Build Coastguard Worker   int64_t numel(const c10::TensorImpl* self) const override;
95*da0073e9SAndroid Build Coastguard Worker   c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
96*da0073e9SAndroid Build Coastguard Worker   c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
97*da0073e9SAndroid Build Coastguard Worker   c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
98*da0073e9SAndroid Build Coastguard Worker 
trace_gpu_event_creationtorch::detail::__anon56d922760111::ConcretePyInterpreterVTable99*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
100*da0073e9SAndroid Build Coastguard Worker       const override {
101*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
102*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_event_deletiontorch::detail::__anon56d922760111::ConcretePyInterpreterVTable103*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
104*da0073e9SAndroid Build Coastguard Worker       const override {
105*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
106*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_event_recordtorch::detail::__anon56d922760111::ConcretePyInterpreterVTable107*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_event_record(
108*da0073e9SAndroid Build Coastguard Worker       at::DeviceType device_type,
109*da0073e9SAndroid Build Coastguard Worker       uintptr_t event,
110*da0073e9SAndroid Build Coastguard Worker       uintptr_t stream) const override {
111*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
112*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_event_waittorch::detail::__anon56d922760111::ConcretePyInterpreterVTable113*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_event_wait(
114*da0073e9SAndroid Build Coastguard Worker       at::DeviceType device_type,
115*da0073e9SAndroid Build Coastguard Worker       uintptr_t event,
116*da0073e9SAndroid Build Coastguard Worker       uintptr_t stream) const override {
117*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
118*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_memory_allocationtorch::detail::__anon56d922760111::ConcretePyInterpreterVTable119*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
120*da0073e9SAndroid Build Coastguard Worker       const override {
121*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
122*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_memory_deallocationtorch::detail::__anon56d922760111::ConcretePyInterpreterVTable123*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
124*da0073e9SAndroid Build Coastguard Worker       const override {
125*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
126*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_stream_creationtorch::detail::__anon56d922760111::ConcretePyInterpreterVTable127*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
128*da0073e9SAndroid Build Coastguard Worker       const override {
129*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
130*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_device_synchronizationtorch::detail::__anon56d922760111::ConcretePyInterpreterVTable131*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_device_synchronization(
132*da0073e9SAndroid Build Coastguard Worker       at::DeviceType device_type) const override {
133*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
134*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_stream_synchronizationtorch::detail::__anon56d922760111::ConcretePyInterpreterVTable135*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_stream_synchronization(
136*da0073e9SAndroid Build Coastguard Worker       at::DeviceType device_type,
137*da0073e9SAndroid Build Coastguard Worker       uintptr_t stream) const override {
138*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
139*da0073e9SAndroid Build Coastguard Worker   }
trace_gpu_event_synchronizationtorch::detail::__anon56d922760111::ConcretePyInterpreterVTable140*da0073e9SAndroid Build Coastguard Worker   void trace_gpu_event_synchronization(
141*da0073e9SAndroid Build Coastguard Worker       at::DeviceType device_type,
142*da0073e9SAndroid Build Coastguard Worker       uintptr_t event) const override {
143*da0073e9SAndroid Build Coastguard Worker     CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
144*da0073e9SAndroid Build Coastguard Worker   }
145*da0073e9SAndroid Build Coastguard Worker 
146*da0073e9SAndroid Build Coastguard Worker   void reset_backward_hooks(const c10::TensorImpl* self) const override;
147*da0073e9SAndroid Build Coastguard Worker 
instancetorch::detail::__anon56d922760111::ConcretePyInterpreterVTable148*da0073e9SAndroid Build Coastguard Worker   static ConcretePyInterpreterVTable* instance() {
149*da0073e9SAndroid Build Coastguard Worker     static ConcretePyInterpreterVTable s;
150*da0073e9SAndroid Build Coastguard Worker     return &s;
151*da0073e9SAndroid Build Coastguard Worker   }
152*da0073e9SAndroid Build Coastguard Worker };
153*da0073e9SAndroid Build Coastguard Worker 
154*da0073e9SAndroid Build Coastguard Worker class PyInterpreterHolder {
155*da0073e9SAndroid Build Coastguard Worker  public:
PyInterpreterHolder()156*da0073e9SAndroid Build Coastguard Worker   PyInterpreterHolder()
157*da0073e9SAndroid Build Coastguard Worker       : impl_(new c10::impl::PyInterpreter(
158*da0073e9SAndroid Build Coastguard Worker             ConcretePyInterpreterVTable::instance())),
159*da0073e9SAndroid Build Coastguard Worker         is_main_interpreter_(
160*da0073e9SAndroid Build Coastguard Worker             at::impl::PythonOpRegistrationTrampoline::registerInterpreter(
161*da0073e9SAndroid Build Coastguard Worker                 impl_)) {}
162*da0073e9SAndroid Build Coastguard Worker   // NB: intentionally leaks the PyInterpreter, as there may still be
163*da0073e9SAndroid Build Coastguard Worker   // references to it that are live, living in objects that aren't being
164*da0073e9SAndroid Build Coastguard Worker   // destructed while Python is being cleaned up.
~PyInterpreterHolder()165*da0073e9SAndroid Build Coastguard Worker   ~PyInterpreterHolder() {
166*da0073e9SAndroid Build Coastguard Worker     impl_->disarm();
167*da0073e9SAndroid Build Coastguard Worker   }
get() const168*da0073e9SAndroid Build Coastguard Worker   c10::impl::PyInterpreter* get() const noexcept {
169*da0073e9SAndroid Build Coastguard Worker     return impl_;
170*da0073e9SAndroid Build Coastguard Worker   }
is_main_interpreter() const171*da0073e9SAndroid Build Coastguard Worker   bool is_main_interpreter() const noexcept {
172*da0073e9SAndroid Build Coastguard Worker     return is_main_interpreter_;
173*da0073e9SAndroid Build Coastguard Worker   }
174*da0073e9SAndroid Build Coastguard Worker 
175*da0073e9SAndroid Build Coastguard Worker  private:
176*da0073e9SAndroid Build Coastguard Worker   c10::impl::PyInterpreter* impl_;
177*da0073e9SAndroid Build Coastguard Worker   bool is_main_interpreter_;
178*da0073e9SAndroid Build Coastguard Worker };
179*da0073e9SAndroid Build Coastguard Worker 
torchDispatchFromTensorImpl(const c10::TensorImpl * self,const char * func_name,PyObject * torch_api_function,const char * module_name,c10::SmallVector<py::object,1> extra_args={})180*da0073e9SAndroid Build Coastguard Worker py::object torchDispatchFromTensorImpl(
181*da0073e9SAndroid Build Coastguard Worker     const c10::TensorImpl* self,
182*da0073e9SAndroid Build Coastguard Worker     const char* func_name,
183*da0073e9SAndroid Build Coastguard Worker     PyObject* torch_api_function,
184*da0073e9SAndroid Build Coastguard Worker     const char* module_name,
185*da0073e9SAndroid Build Coastguard Worker     // WARNING: MUST NOT BE TENSOR ARGS
186*da0073e9SAndroid Build Coastguard Worker     c10::SmallVector<py::object, 1> extra_args = {}) {
187*da0073e9SAndroid Build Coastguard Worker   if (torch_api_function == nullptr) {
188*da0073e9SAndroid Build Coastguard Worker     throw python_error();
189*da0073e9SAndroid Build Coastguard Worker   }
190*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
191*da0073e9SAndroid Build Coastguard Worker       PyGILState_Check(),
192*da0073e9SAndroid Build Coastguard Worker       "GIL must be held before you call parseIValuesToPyArgsKwargs");
193*da0073e9SAndroid Build Coastguard Worker 
194*da0073e9SAndroid Build Coastguard Worker   std::vector<PyObject*> overloaded_args;
195*da0073e9SAndroid Build Coastguard Worker   // TODO: there should be a shorter way to spell this
196*da0073e9SAndroid Build Coastguard Worker   // TODO: fix the constness of target
197*da0073e9SAndroid Build Coastguard Worker   at::Tensor self_t = at::Tensor(
198*da0073e9SAndroid Build Coastguard Worker       c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
199*da0073e9SAndroid Build Coastguard Worker           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
200*da0073e9SAndroid Build Coastguard Worker       unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
201*da0073e9SAndroid Build Coastguard Worker   auto self_p =
202*da0073e9SAndroid Build Coastguard Worker       py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
203*da0073e9SAndroid Build Coastguard Worker   // NB: this may not be a python tensor if you got here from a mode!
204*da0073e9SAndroid Build Coastguard Worker   // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
205*da0073e9SAndroid Build Coastguard Worker   append_overloaded_tensor(&overloaded_args, self_p.ptr());
206*da0073e9SAndroid Build Coastguard Worker   auto args = py::reinterpret_steal<py::object>(
207*da0073e9SAndroid Build Coastguard Worker       PyTuple_New(static_cast<Py_ssize_t>(1 + extra_args.size())));
208*da0073e9SAndroid Build Coastguard Worker   PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
209*da0073e9SAndroid Build Coastguard Worker   int64_t i = 1;
210*da0073e9SAndroid Build Coastguard Worker   for (auto& a : extra_args) {
211*da0073e9SAndroid Build Coastguard Worker     if (a.ptr() == nullptr)
212*da0073e9SAndroid Build Coastguard Worker       throw python_error();
213*da0073e9SAndroid Build Coastguard Worker     PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
214*da0073e9SAndroid Build Coastguard Worker     i++;
215*da0073e9SAndroid Build Coastguard Worker   }
216*da0073e9SAndroid Build Coastguard Worker 
217*da0073e9SAndroid Build Coastguard Worker   py::dict kwargs;
218*da0073e9SAndroid Build Coastguard Worker 
219*da0073e9SAndroid Build Coastguard Worker   return py::reinterpret_steal<py::object>(
220*da0073e9SAndroid Build Coastguard Worker       handle_torch_function_no_python_arg_parser(
221*da0073e9SAndroid Build Coastguard Worker           overloaded_args,
222*da0073e9SAndroid Build Coastguard Worker           args.ptr(),
223*da0073e9SAndroid Build Coastguard Worker           kwargs.ptr(),
224*da0073e9SAndroid Build Coastguard Worker           func_name,
225*da0073e9SAndroid Build Coastguard Worker           torch_api_function,
226*da0073e9SAndroid Build Coastguard Worker           module_name,
227*da0073e9SAndroid Build Coastguard Worker           TorchFunctionName::TorchDispatch));
228*da0073e9SAndroid Build Coastguard Worker }
229*da0073e9SAndroid Build Coastguard Worker 
230*da0073e9SAndroid Build Coastguard Worker // NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
231*da0073e9SAndroid Build Coastguard Worker // Before calling PyInterpreter::decref, we must statically know if the
232*da0073e9SAndroid Build Coastguard Worker // pyobj has a PyObjectSlot or not.
233*da0073e9SAndroid Build Coastguard Worker // - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
234*da0073e9SAndroid Build Coastguard Worker // - If it does not have a PyObjectSlot, we can freely decref
235*da0073e9SAndroid Build Coastguard Worker // One alternative to this is using PyObject_IsInstance
236*da0073e9SAndroid Build Coastguard Worker // to get at this information. However, we don't want to risk an incorrect
237*da0073e9SAndroid Build Coastguard Worker // `__instancecheck__` changing the semantics here.
decref(PyObject * pyobj,bool has_pyobj_slot) const238*da0073e9SAndroid Build Coastguard Worker void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
239*da0073e9SAndroid Build Coastguard Worker     const {
240*da0073e9SAndroid Build Coastguard Worker   // Leak the pyobj if not initialized.  This can happen if we are running
241*da0073e9SAndroid Build Coastguard Worker   // exit handlers that are destructing tensors with residual (owned)
242*da0073e9SAndroid Build Coastguard Worker   // PyObjects stored in them.
243*da0073e9SAndroid Build Coastguard Worker   if (!Py_IsInitialized())
244*da0073e9SAndroid Build Coastguard Worker     return;
245*da0073e9SAndroid Build Coastguard Worker 
246*da0073e9SAndroid Build Coastguard Worker   pybind11::gil_scoped_acquire gil;
247*da0073e9SAndroid Build Coastguard Worker   // Two possibilities:
248*da0073e9SAndroid Build Coastguard Worker   // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
249*da0073e9SAndroid Build Coastguard Worker   // Storage. Then we must be careful about PyObject resurrection (see
250*da0073e9SAndroid Build Coastguard Worker   // THPVariable_clear).
251*da0073e9SAndroid Build Coastguard Worker   // 2. We are decref-ing some other Python object. We don't do
252*da0073e9SAndroid Build Coastguard Worker   // PyObject resurrection on non-Tensors, so we just carry on as usual
253*da0073e9SAndroid Build Coastguard Worker   if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
254*da0073e9SAndroid Build Coastguard Worker     if (THPVariable_Check(pyobj)) {
255*da0073e9SAndroid Build Coastguard Worker       // It's still alive!  This can happen if a weak ref resurrected
256*da0073e9SAndroid Build Coastguard Worker       // the PyObject without flipping ownership.  At this point it is
257*da0073e9SAndroid Build Coastguard Worker       // too late to rescue the object, so just stub out the PyObject
258*da0073e9SAndroid Build Coastguard Worker       // so that it fails on subsequent uses.  Don't raise an error here;
259*da0073e9SAndroid Build Coastguard Worker       // you're probably in a destructor.
260*da0073e9SAndroid Build Coastguard Worker       TORCH_WARN(
261*da0073e9SAndroid Build Coastguard Worker           "Deallocating Tensor that still has live PyObject references.  "
262*da0073e9SAndroid Build Coastguard Worker           "This probably happened because you took out a weak reference to "
263*da0073e9SAndroid Build Coastguard Worker           "Tensor and didn't call _fix_weakref() after dereferencing it.  "
264*da0073e9SAndroid Build Coastguard Worker           "Subsequent accesses to this tensor via the PyObject will now fail.");
265*da0073e9SAndroid Build Coastguard Worker       ((THPVariable*)pyobj)->cdata =
266*da0073e9SAndroid Build Coastguard Worker           c10::MaybeOwned<torch::autograd::Variable>();
267*da0073e9SAndroid Build Coastguard Worker     } else if (THPStorage_Check(pyobj)) {
268*da0073e9SAndroid Build Coastguard Worker       TORCH_WARN(
269*da0073e9SAndroid Build Coastguard Worker           "Deallocating UntypedStorage that still has live PyObject references.  "
270*da0073e9SAndroid Build Coastguard Worker           "This probably happened because you took out a weak reference to "
271*da0073e9SAndroid Build Coastguard Worker           "UntypedStorage and didn't call _fix_weakref() after dereferencing it.  "
272*da0073e9SAndroid Build Coastguard Worker           "Subsequent accesses to this storage via the PyObject will now fail.");
273*da0073e9SAndroid Build Coastguard Worker       ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
274*da0073e9SAndroid Build Coastguard Worker     }
275*da0073e9SAndroid Build Coastguard Worker   }
276*da0073e9SAndroid Build Coastguard Worker   Py_DECREF(pyobj);
277*da0073e9SAndroid Build Coastguard Worker };
278*da0073e9SAndroid Build Coastguard Worker 
incref(PyObject * pyobj) const279*da0073e9SAndroid Build Coastguard Worker void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
280*da0073e9SAndroid Build Coastguard Worker   if (!Py_IsInitialized())
281*da0073e9SAndroid Build Coastguard Worker     return;
282*da0073e9SAndroid Build Coastguard Worker   pybind11::gil_scoped_acquire gil;
283*da0073e9SAndroid Build Coastguard Worker   Py_INCREF(pyobj);
284*da0073e9SAndroid Build Coastguard Worker };
285*da0073e9SAndroid Build Coastguard Worker 
isPythonTensor(const at::Tensor & tensor)286*da0073e9SAndroid Build Coastguard Worker bool isPythonTensor(const at::Tensor& tensor) {
287*da0073e9SAndroid Build Coastguard Worker   return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
288*da0073e9SAndroid Build Coastguard Worker }
289*da0073e9SAndroid Build Coastguard Worker 
reportErrorCallback(PyObject * callback,DispatchKey key) const290*da0073e9SAndroid Build Coastguard Worker void ConcretePyInterpreterVTable::reportErrorCallback(
291*da0073e9SAndroid Build Coastguard Worker     PyObject* callback,
292*da0073e9SAndroid Build Coastguard Worker     DispatchKey key) const {
293*da0073e9SAndroid Build Coastguard Worker   py::gil_scoped_acquire g;
294*da0073e9SAndroid Build Coastguard Worker   auto func = py::reinterpret_borrow<py::object>(callback);
295*da0073e9SAndroid Build Coastguard Worker   // Not all DispatchKeys are pybind'ed into Python and we do not have infra
296*da0073e9SAndroid Build Coastguard Worker   // to ensure this, so just pass a string back to Python.
297*da0073e9SAndroid Build Coastguard Worker   func(c10::toString(key));
298*da0073e9SAndroid Build Coastguard Worker }
299*da0073e9SAndroid Build Coastguard Worker 
dispatch(const c10::OperatorHandle & op,torch::jit::Stack * stack) const300*da0073e9SAndroid Build Coastguard Worker void ConcretePyInterpreterVTable::dispatch(
301*da0073e9SAndroid Build Coastguard Worker     const c10::OperatorHandle& op,
302*da0073e9SAndroid Build Coastguard Worker     torch::jit::Stack* stack) const {
303*da0073e9SAndroid Build Coastguard Worker   const auto& schema = op.schema();
304*da0073e9SAndroid Build Coastguard Worker   const auto num_arguments = schema.arguments().size();
305*da0073e9SAndroid Build Coastguard Worker   auto arguments = torch::jit::pop(*stack, num_arguments);
306*da0073e9SAndroid Build Coastguard Worker 
307*da0073e9SAndroid Build Coastguard Worker   // The plan: convert all the arguments back into PyObjects,
308*da0073e9SAndroid Build Coastguard Worker   // extracting out the tensor handles, then call
309*da0073e9SAndroid Build Coastguard Worker   // handle_torch_function_no_python_arg_parser
310*da0073e9SAndroid Build Coastguard Worker   // NB: at the point arguments are pushed to the stack, ALL defaults
311*da0073e9SAndroid Build Coastguard Worker   // are already present
312*da0073e9SAndroid Build Coastguard Worker 
313*da0073e9SAndroid Build Coastguard Worker   py::gil_scoped_acquire g;
314*da0073e9SAndroid Build Coastguard Worker 
315*da0073e9SAndroid Build Coastguard Worker   std::vector<PyObject*> overloaded_args;
316*da0073e9SAndroid Build Coastguard Worker   py::handle torch_api_function_overload = getTorchApiFunction(op);
317*da0073e9SAndroid Build Coastguard Worker 
318*da0073e9SAndroid Build Coastguard Worker   // Find overloaded tensors
319*da0073e9SAndroid Build Coastguard Worker   for (const auto idx : c10::irange(arguments.size())) {
320*da0073e9SAndroid Build Coastguard Worker     const auto& ivalue = arguments[idx];
321*da0073e9SAndroid Build Coastguard Worker     if (ivalue.isTensor()) {
322*da0073e9SAndroid Build Coastguard Worker       const auto& tensor = ivalue.toTensor();
323*da0073e9SAndroid Build Coastguard Worker       if (isPythonTensor(tensor)) {
324*da0073e9SAndroid Build Coastguard Worker         append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
325*da0073e9SAndroid Build Coastguard Worker       }
326*da0073e9SAndroid Build Coastguard Worker     } else if (ivalue.isList()) {
327*da0073e9SAndroid Build Coastguard Worker       const auto& list = ivalue.toListRef();
328*da0073e9SAndroid Build Coastguard Worker       for (const auto jdx : c10::irange(list.size())) {
329*da0073e9SAndroid Build Coastguard Worker         const auto& nv = list[jdx];
330*da0073e9SAndroid Build Coastguard Worker         if (nv.isTensor()) {
331*da0073e9SAndroid Build Coastguard Worker           const auto& tensor = nv.toTensor();
332*da0073e9SAndroid Build Coastguard Worker           if (isPythonTensor(tensor)) {
333*da0073e9SAndroid Build Coastguard Worker             append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
334*da0073e9SAndroid Build Coastguard Worker           }
335*da0073e9SAndroid Build Coastguard Worker         }
336*da0073e9SAndroid Build Coastguard Worker       }
337*da0073e9SAndroid Build Coastguard Worker     }
338*da0073e9SAndroid Build Coastguard Worker   }
339*da0073e9SAndroid Build Coastguard Worker 
340*da0073e9SAndroid Build Coastguard Worker   auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
341*da0073e9SAndroid Build Coastguard Worker   auto args = std::move(args_kwargs.first);
342*da0073e9SAndroid Build Coastguard Worker   auto kwargs = std::move(args_kwargs.second);
343*da0073e9SAndroid Build Coastguard Worker 
344*da0073e9SAndroid Build Coastguard Worker   PyObject* obj = handle_torch_function_no_python_arg_parser(
345*da0073e9SAndroid Build Coastguard Worker       overloaded_args,
346*da0073e9SAndroid Build Coastguard Worker       args.ptr(),
347*da0073e9SAndroid Build Coastguard Worker       kwargs.ptr(),
348*da0073e9SAndroid Build Coastguard Worker       nullptr,
349*da0073e9SAndroid Build Coastguard Worker       torch_api_function_overload.ptr(),
350*da0073e9SAndroid Build Coastguard Worker       nullptr,
351*da0073e9SAndroid Build Coastguard Worker       TorchFunctionName::TorchDispatch);
352*da0073e9SAndroid Build Coastguard Worker   pushPyOutToStack(
353*da0073e9SAndroid Build Coastguard Worker       op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
354*da0073e9SAndroid Build Coastguard Worker }
355*da0073e9SAndroid Build Coastguard Worker 
python_dispatcher(const c10::OperatorHandle & op,c10::DispatchKeySet ks,torch::jit::Stack * stack) const356*da0073e9SAndroid Build Coastguard Worker void ConcretePyInterpreterVTable::python_dispatcher(
357*da0073e9SAndroid Build Coastguard Worker     const c10::OperatorHandle& op,
358*da0073e9SAndroid Build Coastguard Worker     c10::DispatchKeySet ks,
359*da0073e9SAndroid Build Coastguard Worker     torch::jit::Stack* stack) const {
360*da0073e9SAndroid Build Coastguard Worker   py::gil_scoped_acquire g;
361*da0073e9SAndroid Build Coastguard Worker   py::handle torch_api_function_overload = getTorchApiFunction(op);
362*da0073e9SAndroid Build Coastguard Worker   // TODO: if necessary, can optimize to cache the cache lookup
363*da0073e9SAndroid Build Coastguard Worker   // TODO: if necessary, can optimize OpOverload to have slots
364*da0073e9SAndroid Build Coastguard Worker   auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache"));
365*da0073e9SAndroid Build Coastguard Worker   if (cache.ptr() == nullptr) {
366*da0073e9SAndroid Build Coastguard Worker     throw python_error();
367*da0073e9SAndroid Build Coastguard Worker   }
368*da0073e9SAndroid Build Coastguard Worker 
369*da0073e9SAndroid Build Coastguard Worker   c10::DispatchKey k = ks.highestPriorityTypeId();
370*da0073e9SAndroid Build Coastguard Worker   // TODO: allow this to be non-owning
371*da0073e9SAndroid Build Coastguard Worker   auto handler = py::reinterpret_borrow<py::object>(
372*da0073e9SAndroid Build Coastguard Worker       PyDict_GetItem(cache.ptr(), py::cast(k).ptr()));
373*da0073e9SAndroid Build Coastguard Worker   if (handler.ptr() == nullptr) {
374*da0073e9SAndroid Build Coastguard Worker     // Slow path
375*da0073e9SAndroid Build Coastguard Worker     handler = torch_api_function_overload.attr("_get_dispatch")(k);
376*da0073e9SAndroid Build Coastguard Worker   }
377*da0073e9SAndroid Build Coastguard Worker   if (py::isinstance<c10::DispatchKey>(handler)) {
378*da0073e9SAndroid Build Coastguard Worker     // NB: not redispatch, as that will permanently remove the python
379*da0073e9SAndroid Build Coastguard Worker     // dispatcher for subsequent redispatches
380*da0073e9SAndroid Build Coastguard Worker     op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack);
381*da0073e9SAndroid Build Coastguard Worker     return;
382*da0073e9SAndroid Build Coastguard Worker   }
383*da0073e9SAndroid Build Coastguard Worker 
384*da0073e9SAndroid Build Coastguard Worker   const auto& schema = op.schema();
385*da0073e9SAndroid Build Coastguard Worker   const auto num_arguments = schema.arguments().size();
386*da0073e9SAndroid Build Coastguard Worker   auto arguments = torch::jit::pop(*stack, num_arguments);
387*da0073e9SAndroid Build Coastguard Worker 
388*da0073e9SAndroid Build Coastguard Worker   auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
389*da0073e9SAndroid Build Coastguard Worker   auto args = std::move(args_kwargs.first);
390*da0073e9SAndroid Build Coastguard Worker   auto kwargs = std::move(args_kwargs.second);
391*da0073e9SAndroid Build Coastguard Worker 
392*da0073e9SAndroid Build Coastguard Worker   py::object obj = py::reinterpret_steal<py::object>(
393*da0073e9SAndroid Build Coastguard Worker       PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));
394*da0073e9SAndroid Build Coastguard Worker 
395*da0073e9SAndroid Build Coastguard Worker   if (obj.ptr() == nullptr) {
396*da0073e9SAndroid Build Coastguard Worker     throw python_error();
397*da0073e9SAndroid Build Coastguard Worker   }
398*da0073e9SAndroid Build Coastguard Worker 
399*da0073e9SAndroid Build Coastguard Worker   pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher");
400*da0073e9SAndroid Build Coastguard Worker }
401*da0073e9SAndroid Build Coastguard Worker 
// Routes detach() on a Python-backed TensorImpl through __torch_dispatch__
// via torch.ops.aten.detach.default and returns the resulting TensorImpl.
// Throws if the Python side returns anything other than a Tensor.
c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object aten_ops =
      py::module::import("torch").attr("ops").attr("aten");
  py::object result = torchDispatchFromTensorImpl(
      self,
      "detach",
      aten_ops.attr("detach").attr("default").ptr(),
      "torch.ops.aten");

  PyObject* raw = result.ptr();
  TORCH_CHECK(
      THPVariable_Check(raw),
      "detach returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(raw)),
      ", expected Tensor");
  return THPVariable_Unpack(raw).getIntrusivePtr();
}
426*da0073e9SAndroid Build Coastguard Worker 
is_contiguous(const c10::TensorImpl * self,at::MemoryFormat memory_format) const427*da0073e9SAndroid Build Coastguard Worker bool ConcretePyInterpreterVTable::is_contiguous(
428*da0073e9SAndroid Build Coastguard Worker     const c10::TensorImpl* self,
429*da0073e9SAndroid Build Coastguard Worker     at::MemoryFormat memory_format) const {
430*da0073e9SAndroid Build Coastguard Worker   pybind11::gil_scoped_acquire gil;
431*da0073e9SAndroid Build Coastguard Worker   at::impl::MaybeSetTLSOnEntryGuard guard;
432*da0073e9SAndroid Build Coastguard Worker 
433*da0073e9SAndroid Build Coastguard Worker   py::object out;
434*da0073e9SAndroid Build Coastguard Worker   if (memory_format == at::MemoryFormat::Contiguous) {
435*da0073e9SAndroid Build Coastguard Worker     // For backwards compatibility
436*da0073e9SAndroid Build Coastguard Worker     out = torchDispatchFromTensorImpl(
437*da0073e9SAndroid Build Coastguard Worker         self,
438*da0073e9SAndroid Build Coastguard Worker         "is_contiguous",
439*da0073e9SAndroid Build Coastguard Worker         py::module::import("torch")
440*da0073e9SAndroid Build Coastguard Worker             .attr("ops")
441*da0073e9SAndroid Build Coastguard Worker             .attr("aten")
442*da0073e9SAndroid Build Coastguard Worker             .attr("is_contiguous")
443*da0073e9SAndroid Build Coastguard Worker             .attr("default")
444*da0073e9SAndroid Build Coastguard Worker             .ptr(),
445*da0073e9SAndroid Build Coastguard Worker         "torch.ops.aten");
446*da0073e9SAndroid Build Coastguard Worker   } else {
447*da0073e9SAndroid Build Coastguard Worker     out = torchDispatchFromTensorImpl(
448*da0073e9SAndroid Build Coastguard Worker         self,
449*da0073e9SAndroid Build Coastguard Worker         "is_contiguous",
450*da0073e9SAndroid Build Coastguard Worker         py::module::import("torch")
451*da0073e9SAndroid Build Coastguard Worker             .attr("ops")
452*da0073e9SAndroid Build Coastguard Worker             .attr("aten")
453*da0073e9SAndroid Build Coastguard Worker             .attr("is_contiguous")
454*da0073e9SAndroid Build Coastguard Worker             .attr("memory_format")
455*da0073e9SAndroid Build Coastguard Worker             .ptr(),
456*da0073e9SAndroid Build Coastguard Worker         "torch.ops.aten",
457*da0073e9SAndroid Build Coastguard Worker         {py::cast(memory_format)});
458*da0073e9SAndroid Build Coastguard Worker   }
459*da0073e9SAndroid Build Coastguard Worker 
460*da0073e9SAndroid Build Coastguard Worker   if (out.is_none()) {
461*da0073e9SAndroid Build Coastguard Worker     return self->is_contiguous_default(memory_format);
462*da0073e9SAndroid Build Coastguard Worker   }
463*da0073e9SAndroid Build Coastguard Worker 
464*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
465*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(out.ptr()),
466*da0073e9SAndroid Build Coastguard Worker       "is_contiguous returned invalid type ",
467*da0073e9SAndroid Build Coastguard Worker       py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
468*da0073e9SAndroid Build Coastguard Worker       ", expected bool");
469*da0073e9SAndroid Build Coastguard Worker 
470*da0073e9SAndroid Build Coastguard Worker   return PyObject_IsTrue(out.ptr());
471*da0073e9SAndroid Build Coastguard Worker }
472*da0073e9SAndroid Build Coastguard Worker 
is_strides_like(const c10::TensorImpl * self,at::MemoryFormat memory_format) const473*da0073e9SAndroid Build Coastguard Worker bool ConcretePyInterpreterVTable::is_strides_like(
474*da0073e9SAndroid Build Coastguard Worker     const c10::TensorImpl* self,
475*da0073e9SAndroid Build Coastguard Worker     at::MemoryFormat memory_format) const {
476*da0073e9SAndroid Build Coastguard Worker   pybind11::gil_scoped_acquire gil;
477*da0073e9SAndroid Build Coastguard Worker   at::impl::MaybeSetTLSOnEntryGuard guard;
478*da0073e9SAndroid Build Coastguard Worker 
479*da0073e9SAndroid Build Coastguard Worker   auto out = torchDispatchFromTensorImpl(
480*da0073e9SAndroid Build Coastguard Worker       self,
481*da0073e9SAndroid Build Coastguard Worker       "is_strides_like",
482*da0073e9SAndroid Build Coastguard Worker       py::module::import("torch")
483*da0073e9SAndroid Build Coastguard Worker           .attr("ops")
484*da0073e9SAndroid Build Coastguard Worker           .attr("aten")
485*da0073e9SAndroid Build Coastguard Worker           // NB: intentionally suffixed with _format to avoid
486*da0073e9SAndroid Build Coastguard Worker           // triggering matches against "_like" suffix
487*da0073e9SAndroid Build Coastguard Worker           .attr("is_strides_like_format")
488*da0073e9SAndroid Build Coastguard Worker           .attr("default")
489*da0073e9SAndroid Build Coastguard Worker           .ptr(),
490*da0073e9SAndroid Build Coastguard Worker       "torch.ops.aten",
491*da0073e9SAndroid Build Coastguard Worker       {py::cast(memory_format)});
492*da0073e9SAndroid Build Coastguard Worker 
493*da0073e9SAndroid Build Coastguard Worker   if (out.is_none()) {
494*da0073e9SAndroid Build Coastguard Worker     return self->is_strides_like_default(memory_format);
495*da0073e9SAndroid Build Coastguard Worker   }
496*da0073e9SAndroid Build Coastguard Worker 
497*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
498*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(out.ptr()),
499*da0073e9SAndroid Build Coastguard Worker       "is_strides_like_format returned invalid type ",
500*da0073e9SAndroid Build Coastguard Worker       py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
501*da0073e9SAndroid Build Coastguard Worker       ", expected bool");
502*da0073e9SAndroid Build Coastguard Worker 
503*da0073e9SAndroid Build Coastguard Worker   return PyObject_IsTrue(out.ptr());
504*da0073e9SAndroid Build Coastguard Worker }
505*da0073e9SAndroid Build Coastguard Worker 
is_non_overlapping_and_dense(const c10::TensorImpl * self) const506*da0073e9SAndroid Build Coastguard Worker bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
507*da0073e9SAndroid Build Coastguard Worker     const c10::TensorImpl* self) const {
508*da0073e9SAndroid Build Coastguard Worker   pybind11::gil_scoped_acquire gil;
509*da0073e9SAndroid Build Coastguard Worker   at::impl::MaybeSetTLSOnEntryGuard guard;
510*da0073e9SAndroid Build Coastguard Worker 
511*da0073e9SAndroid Build Coastguard Worker   auto out = torchDispatchFromTensorImpl(
512*da0073e9SAndroid Build Coastguard Worker       self,
513*da0073e9SAndroid Build Coastguard Worker       "is_non_overlapping_and_dense",
514*da0073e9SAndroid Build Coastguard Worker       py::module::import("torch")
515*da0073e9SAndroid Build Coastguard Worker           .attr("ops")
516*da0073e9SAndroid Build Coastguard Worker           .attr("aten")
517*da0073e9SAndroid Build Coastguard Worker           .attr("is_non_overlapping_and_dense")
518*da0073e9SAndroid Build Coastguard Worker           .attr("default")
519*da0073e9SAndroid Build Coastguard Worker           .ptr(),
520*da0073e9SAndroid Build Coastguard Worker       "torch.ops.aten");
521*da0073e9SAndroid Build Coastguard Worker 
522*da0073e9SAndroid Build Coastguard Worker   if (out.is_none()) {
523*da0073e9SAndroid Build Coastguard Worker     return self->is_non_overlapping_and_dense_default();
524*da0073e9SAndroid Build Coastguard Worker   }
525*da0073e9SAndroid Build Coastguard Worker 
526*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
527*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(out.ptr()),
528*da0073e9SAndroid Build Coastguard Worker       "is_non_overlapping_and_dense returned invalid type ",
529*da0073e9SAndroid Build Coastguard Worker       py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
530*da0073e9SAndroid Build Coastguard Worker       ", expected bool");
531*da0073e9SAndroid Build Coastguard Worker 
532*da0073e9SAndroid Build Coastguard Worker   return PyObject_IsTrue(out.ptr());
533*da0073e9SAndroid Build Coastguard Worker }
534*da0073e9SAndroid Build Coastguard Worker 
dim(const c10::TensorImpl * self) const535*da0073e9SAndroid Build Coastguard Worker int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
536*da0073e9SAndroid Build Coastguard Worker   pybind11::gil_scoped_acquire gil;
537*da0073e9SAndroid Build Coastguard Worker   at::impl::MaybeSetTLSOnEntryGuard guard;
538*da0073e9SAndroid Build Coastguard Worker 
539*da0073e9SAndroid Build Coastguard Worker   auto out = torchDispatchFromTensorImpl(
540*da0073e9SAndroid Build Coastguard Worker       self,
541*da0073e9SAndroid Build Coastguard Worker       "dim",
542*da0073e9SAndroid Build Coastguard Worker       py::module::import("torch")
543*da0073e9SAndroid Build Coastguard Worker           .attr("ops")
544*da0073e9SAndroid Build Coastguard Worker           .attr("aten")
545*da0073e9SAndroid Build Coastguard Worker           .attr("dim")
546*da0073e9SAndroid Build Coastguard Worker           .attr("default")
547*da0073e9SAndroid Build Coastguard Worker           .ptr(),
548*da0073e9SAndroid Build Coastguard Worker       "torch.ops.aten");
549*da0073e9SAndroid Build Coastguard Worker 
550*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
551*da0073e9SAndroid Build Coastguard Worker       PyLong_Check(out.ptr()),
552*da0073e9SAndroid Build Coastguard Worker       "dim returned invalid type ",
553*da0073e9SAndroid Build Coastguard Worker       py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
554*da0073e9SAndroid Build Coastguard Worker       ", expected int");
555*da0073e9SAndroid Build Coastguard Worker 
556*da0073e9SAndroid Build Coastguard Worker   return THPUtils_unpackLong(out.ptr());
557*da0073e9SAndroid Build Coastguard Worker }
558*da0073e9SAndroid Build Coastguard Worker 
// Returns the device reported by the Python subclass via
// torch.ops.prim.device.default, converted with toDevice().
c10::Device ConcretePyInterpreterVTable::device(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object prim_ops =
      py::module::import("torch").attr("ops").attr("prim");
  py::object result = torchDispatchFromTensorImpl(
      self,
      "device",
      prim_ops.attr("device").attr("default").ptr(),
      "torch.ops.prim");
  return toDevice(result.ptr());
}
577*da0073e9SAndroid Build Coastguard Worker 
set_tensor_attr_with_capsule(const c10::TensorImpl * tensor,py::capsule & capsule,const char * attr_name)578*da0073e9SAndroid Build Coastguard Worker static void set_tensor_attr_with_capsule(
579*da0073e9SAndroid Build Coastguard Worker     const c10::TensorImpl* tensor,
580*da0073e9SAndroid Build Coastguard Worker     py::capsule& capsule,
581*da0073e9SAndroid Build Coastguard Worker     const char* attr_name) {
582*da0073e9SAndroid Build Coastguard Worker   std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
583*da0073e9SAndroid Build Coastguard Worker       getPyInterpreter(), /*ignore_hermetic_tls=*/false);
584*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
585*da0073e9SAndroid Build Coastguard Worker       mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
586*da0073e9SAndroid Build Coastguard Worker   auto obj = mb_obj.value();
587*da0073e9SAndroid Build Coastguard Worker   py::handle(obj).attr(attr_name) = capsule;
588*da0073e9SAndroid Build Coastguard Worker }
589*da0073e9SAndroid Build Coastguard Worker 
590*da0073e9SAndroid Build Coastguard Worker // Note [Tensor Subclass custom size/stride caching strategy]
591*da0073e9SAndroid Build Coastguard Worker // Tensor subclasses can use __torch_dispatch__ to override size/stride calls.
592*da0073e9SAndroid Build Coastguard Worker // However, this presents a problem:
593*da0073e9SAndroid Build Coastguard Worker // (1) When you return a custom (maybe symbolic) size/stride
594*da0073e9SAndroid Build Coastguard Worker //     from python, we need to stash this fresh vector of ints/symints
595*da0073e9SAndroid Build Coastguard Worker //     somewhere so that it has the same lifetime as the tensor.
596*da0073e9SAndroid Build Coastguard Worker // (2) If the subclass experiences a metadata mutation,
597*da0073e9SAndroid Build Coastguard Worker //     this stashed vector is no longer valid, so we need to allocate a fresh
598*da0073e9SAndroid Build Coastguard Worker //     buffer to store the new sizes the next time someone asks for them.
599*da0073e9SAndroid Build Coastguard Worker //
600*da0073e9SAndroid Build Coastguard Worker // We handle this in the same way that `TensorImpl::sizes_default()`
601*da0073e9SAndroid Build Coastguard Worker // handles its buffer: we simply reallocate the buffer whenever
602*da0073e9SAndroid Build Coastguard Worker // the number of dimensions changes due to a resize.
603*da0073e9SAndroid Build Coastguard Worker // Notable, we do *not* reallocate the buffer if the values changed,
604*da0073e9SAndroid Build Coastguard Worker // but the number of dimensions stayed the same (e.g. `.transpose_()`).
605*da0073e9SAndroid Build Coastguard Worker template <typename T>
get_set_cached_attr(const c10::TensorImpl * tensor,const char * base_attr_name,const py::object & obj)606*da0073e9SAndroid Build Coastguard Worker static c10::ArrayRef<T> get_set_cached_attr(
607*da0073e9SAndroid Build Coastguard Worker     const c10::TensorImpl* tensor,
608*da0073e9SAndroid Build Coastguard Worker     const char* base_attr_name,
609*da0073e9SAndroid Build Coastguard Worker     const py::object& obj) {
610*da0073e9SAndroid Build Coastguard Worker   std::optional<PyObject*> mb_obj =
611*da0073e9SAndroid Build Coastguard Worker       tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
612*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
613*da0073e9SAndroid Build Coastguard Worker       mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
614*da0073e9SAndroid Build Coastguard Worker   auto tensor_obj = mb_obj.value();
615*da0073e9SAndroid Build Coastguard Worker   auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");
616*da0073e9SAndroid Build Coastguard Worker 
617*da0073e9SAndroid Build Coastguard Worker   bool is_buffer_allocated = false;
618*da0073e9SAndroid Build Coastguard Worker   size_t curr_size = 0;
619*da0073e9SAndroid Build Coastguard Worker   if (PyObject_HasAttrString(tensor_obj, buffer_len_attr_name.c_str())) {
620*da0073e9SAndroid Build Coastguard Worker     auto len_pyobj = py::handle(tensor_obj).attr(buffer_len_attr_name.c_str());
621*da0073e9SAndroid Build Coastguard Worker     curr_size = py::cast<size_t>(len_pyobj);
622*da0073e9SAndroid Build Coastguard Worker     is_buffer_allocated = true;
623*da0073e9SAndroid Build Coastguard Worker   }
624*da0073e9SAndroid Build Coastguard Worker 
625*da0073e9SAndroid Build Coastguard Worker   size_t new_size = py::len(obj);
626*da0073e9SAndroid Build Coastguard Worker 
627*da0073e9SAndroid Build Coastguard Worker   // We do the smallvector optimization here: any time the new_size is <=5,
628*da0073e9SAndroid Build Coastguard Worker   // we always allocate our buffer to size 5, so that if the next resize
629*da0073e9SAndroid Build Coastguard Worker   // is also to <=5 elements, we don't need to reallocate.
630*da0073e9SAndroid Build Coastguard Worker   // Note: I tried removing this optimization and tripped ASAN
631*da0073e9SAndroid Build Coastguard Worker   // in a batchnorm kernel here:
632*da0073e9SAndroid Build Coastguard Worker   // https://pipelinesghubeus21.actions.githubusercontent.com/mBh68xKhi8LyM7tp3vECvYXNFvuV4gyVGgmYCteuEZP9JH92QN/_apis/pipelines/1/runs/3373307/signedlogcontent/790?urlExpires=2023-09-15T21%3A13%3A51.4327798Z&urlSigningMethod=HMACV1&urlSignature=tDeX7ZqaARVU5NNwyr5yYqqkWq3A2j4z8FFdqYwGr0Q%3D
633*da0073e9SAndroid Build Coastguard Worker   // We should fix this instead.
634*da0073e9SAndroid Build Coastguard Worker   bool needs_resize = false;
635*da0073e9SAndroid Build Coastguard Worker   // We need to resize if:
636*da0073e9SAndroid Build Coastguard Worker   // (1) we haven't allocated our buffer at all yet
637*da0073e9SAndroid Build Coastguard Worker   // (2) Our buffer size is different from the new size
638*da0073e9SAndroid Build Coastguard Worker   //     (note: we use the small vector optimization, where our buffer
639*da0073e9SAndroid Build Coastguard Worker   //     is always allocated to at least size 5, and any resizes
640*da0073e9SAndroid Build Coastguard Worker   //     within the <= 5 regime to not require a reallocation).
641*da0073e9SAndroid Build Coastguard Worker   auto is_smallvector = curr_size <= 5;
642*da0073e9SAndroid Build Coastguard Worker   needs_resize = !is_buffer_allocated || (is_smallvector && new_size > 5) ||
643*da0073e9SAndroid Build Coastguard Worker       (!is_smallvector && curr_size != new_size);
644*da0073e9SAndroid Build Coastguard Worker   if (needs_resize) {
645*da0073e9SAndroid Build Coastguard Worker     // If our current buffer is not the right size (either because we haven't
646*da0073e9SAndroid Build Coastguard Worker     // allocated it yet, or there was a metadata mutation that changed the
647*da0073e9SAndroid Build Coastguard Worker     // number of dims of the tensor), allocate a fresh buffer. Note that this
648*da0073e9SAndroid Build Coastguard Worker     // will trash the previous buffer if there already was one, invalidating any
649*da0073e9SAndroid Build Coastguard Worker     // existing SymIntArrayRef's from an old .sym_size() call.
650*da0073e9SAndroid Build Coastguard Worker     auto new_buffer_size = new_size;
651*da0073e9SAndroid Build Coastguard Worker     if (new_size <= 5) {
652*da0073e9SAndroid Build Coastguard Worker       // This is the smallvector optimization
653*da0073e9SAndroid Build Coastguard Worker       new_buffer_size = 5;
654*da0073e9SAndroid Build Coastguard Worker     }
655*da0073e9SAndroid Build Coastguard Worker     T* ptr = new T[new_buffer_size];
656*da0073e9SAndroid Build Coastguard Worker     auto capsule =
657*da0073e9SAndroid Build Coastguard Worker         py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<T*>(p); });
658*da0073e9SAndroid Build Coastguard Worker     int64_t idx = 0;
659*da0073e9SAndroid Build Coastguard Worker     for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
660*da0073e9SAndroid Build Coastguard Worker       ptr[idx] = py::cast<T>(*it);
661*da0073e9SAndroid Build Coastguard Worker     }
662*da0073e9SAndroid Build Coastguard Worker     // Set the buffer
663*da0073e9SAndroid Build Coastguard Worker     set_tensor_attr_with_capsule(tensor, capsule, base_attr_name);
664*da0073e9SAndroid Build Coastguard Worker     // Set the len buffer
665*da0073e9SAndroid Build Coastguard Worker     py::handle(tensor_obj).attr(buffer_len_attr_name.c_str()) = new_size;
666*da0073e9SAndroid Build Coastguard Worker   } else {
667*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(PyObject_HasAttrString(tensor_obj, base_attr_name));
668*da0073e9SAndroid Build Coastguard Worker     auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
669*da0073e9SAndroid Build Coastguard Worker     void* buffer_pycapsule =
670*da0073e9SAndroid Build Coastguard Worker         PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
671*da0073e9SAndroid Build Coastguard Worker     auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
672*da0073e9SAndroid Build Coastguard Worker 
673*da0073e9SAndroid Build Coastguard Worker     // Overwrite the buffer with our new values, but only if any of them changed
674*da0073e9SAndroid Build Coastguard Worker     // (due to a metadata mutation).
675*da0073e9SAndroid Build Coastguard Worker     // This is technically not thread safe, because the update happens lazily.
676*da0073e9SAndroid Build Coastguard Worker     // The original metadata mutation call on the tensor might have been thread
677*da0073e9SAndroid Build Coastguard Worker     // safe (e.g. a .resize_() call), but we won't actually mutate the size
678*da0073e9SAndroid Build Coastguard Worker     // buffer until the first call to .sizes() which the user might not access
679*da0073e9SAndroid Build Coastguard Worker     // in a thread-safe way. For now we are not explicitly locking, but maybe we
680*da0073e9SAndroid Build Coastguard Worker     // should.
681*da0073e9SAndroid Build Coastguard Worker     int64_t idx = 0;
682*da0073e9SAndroid Build Coastguard Worker     // Quick sanity assert that our buffer size is large enough
683*da0073e9SAndroid Build Coastguard Worker     // to compare against all the elements in the new buffer.
684*da0073e9SAndroid Build Coastguard Worker     size_t curr_buffer_size = 5;
685*da0073e9SAndroid Build Coastguard Worker     if (curr_buffer_size < curr_size) {
686*da0073e9SAndroid Build Coastguard Worker       curr_buffer_size = curr_size;
687*da0073e9SAndroid Build Coastguard Worker     }
688*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(curr_buffer_size >= new_size);
689*da0073e9SAndroid Build Coastguard Worker     for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
690*da0073e9SAndroid Build Coastguard Worker       auto actual_val = py::cast<T>(*it);
691*da0073e9SAndroid Build Coastguard Worker       if constexpr (std::is_same_v<T, c10::SymInt>) {
692*da0073e9SAndroid Build Coastguard Worker         // if our SymInts are symbolic, we are *not* doing an equality check on
693*da0073e9SAndroid Build Coastguard Worker         // the symints. we just want to see if the nodes are the same. this is
694*da0073e9SAndroid Build Coastguard Worker         // because we don't want to introduce any guards here.
695*da0073e9SAndroid Build Coastguard Worker         if (!curr_buffer[idx].is_same(actual_val)) {
696*da0073e9SAndroid Build Coastguard Worker           curr_buffer[idx] = actual_val;
697*da0073e9SAndroid Build Coastguard Worker         }
698*da0073e9SAndroid Build Coastguard Worker       } else {
699*da0073e9SAndroid Build Coastguard Worker         if (curr_buffer[idx] != actual_val) {
700*da0073e9SAndroid Build Coastguard Worker           curr_buffer[idx] = actual_val;
701*da0073e9SAndroid Build Coastguard Worker         }
702*da0073e9SAndroid Build Coastguard Worker       }
703*da0073e9SAndroid Build Coastguard Worker     }
704*da0073e9SAndroid Build Coastguard Worker   }
705*da0073e9SAndroid Build Coastguard Worker 
706*da0073e9SAndroid Build Coastguard Worker   // The correct data is now stored at the buffer - read and return it.
707*da0073e9SAndroid Build Coastguard Worker   auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
708*da0073e9SAndroid Build Coastguard Worker   void* buffer_pycapsule =
709*da0073e9SAndroid Build Coastguard Worker       PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
710*da0073e9SAndroid Build Coastguard Worker   auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
711*da0073e9SAndroid Build Coastguard Worker   return c10::ArrayRef<T>(curr_buffer, new_size);
712*da0073e9SAndroid Build Coastguard Worker }
713*da0073e9SAndroid Build Coastguard Worker 
// Returns the strides reported by the Python subclass via
// torch.ops.aten.stride.default, cached on the tensor's Python object.
// A None result falls back to TensorImpl::strides_default, which is only
// allowed when the tensor has no symbolic sizes/strides.
c10::IntArrayRef ConcretePyInterpreterVTable::strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  const py::object aten_ops =
      py::module::import("torch").attr("ops").attr("aten");
  py::object result = torchDispatchFromTensorImpl(
      self,
      "stride",
      aten_ops.attr("stride").attr("default").ptr(),
      "torch.ops.aten");

  if (!result.is_none()) {
    TORCH_CHECK(
        py::isinstance<py::tuple>(result) || py::isinstance<py::list>(result),
        "strides must be a list or a tuple");
    return get_set_cached_attr<int64_t>(self, "_strides_capsule", result);
  }

  TORCH_CHECK(
      !self->has_symbolic_sizes_strides(),
      "Cannot call strides on a tensor with symbolic shapes/strides");
  return self->strides_default();
}
743*da0073e9SAndroid Build Coastguard Worker 
// Returns the sizes reported by the Python subclass via
// torch.ops.aten.size.default, cached on the tensor's Python object.
// A None result falls back to TensorImpl::sizes_default, which is only
// allowed when the tensor has no symbolic sizes/strides. The body is wrapped
// in HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS_PYBIND.
c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  const py::object aten_ops =
      py::module::import("torch").attr("ops").attr("aten");
  py::object result = torchDispatchFromTensorImpl(
      self,
      "size",
      aten_ops.attr("size").attr("default").ptr(),
      "torch.ops.aten");

  if (!result.is_none()) {
    TORCH_CHECK(
        py::isinstance<py::tuple>(result) || py::isinstance<py::list>(result),
        "sizes must be a list or a tuple");
    return get_set_cached_attr<int64_t>(self, "_sizes_capsule", result);
  }

  TORCH_CHECK(
      !self->has_symbolic_sizes_strides(),
      "Cannot call sizes on a tensor with symbolic shapes/strides");
  return self->sizes_default();
  END_HANDLE_TH_ERRORS_PYBIND
}
774*da0073e9SAndroid Build Coastguard Worker 
// Symbolic sizes for a tensor whose TensorImpl defers size queries to Python.
// Dispatches torch.ops.aten.sym_size.default through
// torchDispatchFromTensorImpl.
//
// A None result falls back to the TensorImpl's sym_sizes_default(). Unlike
// sizes(), no has_symbolic_sizes_strides() check is made on that path.
// Any other result must be a list or tuple, and it is passed to
// get_set_cached_attr<c10::SymInt> under the "_sym_sizes_capsule" key.
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  py::object sym_size_op = py::module::import("torch")
                               .attr("ops")
                               .attr("aten")
                               .attr("sym_size")
                               .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "sym_size", sym_size_op.ptr(), "torch.ops.aten");

  if (!out.is_none()) {
    TORCH_CHECK(
        py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
        "sym_size must be a list or a tuple");
    // See Note [Tensor Subclass custom size/stride caching strategy]
    return get_set_cached_attr<c10::SymInt>(self, "_sym_sizes_capsule", out);
  }
  return self->sym_sizes_default();
  END_HANDLE_TH_ERRORS_PYBIND
}
804*da0073e9SAndroid Build Coastguard Worker 
// Layout for a tensor whose TensorImpl defers layout queries to Python.
// Dispatches torch.ops.prim.layout.default through torchDispatchFromTensorImpl.
//
// The Python result may be a torch.layout object (converted with toLayout)
// or a Python int, which is cast to int64_t and used directly as the
// c10::Layout value. Anything else, including None, fails the TORCH_CHECK.
c10::Layout ConcretePyInterpreterVTable::layout(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  py::object layout_op = py::module::import("torch")
                             .attr("ops")
                             .attr("prim")
                             .attr("layout")
                             .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "layout", layout_op.ptr(), "torch.ops.prim");

  PyObject* const result = out.ptr();
  if (THPLayout_Check(result)) {
    return toLayout(result);
  }

  // Not a torch.layout: the only other accepted form is a Python int.
  TORCH_CHECK(
      PyLong_Check(result),
      "layout returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result)),
      ", expected Layout");
  return c10::Layout(py::cast<int64_t>(out));
}
832*da0073e9SAndroid Build Coastguard Worker 
numel(const c10::TensorImpl * self) const833*da0073e9SAndroid Build Coastguard Worker int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const {
834*da0073e9SAndroid Build Coastguard Worker   pybind11::gil_scoped_acquire gil;
835*da0073e9SAndroid Build Coastguard Worker   at::impl::MaybeSetTLSOnEntryGuard guard;
836*da0073e9SAndroid Build Coastguard Worker   auto out = torchDispatchFromTensorImpl(
837*da0073e9SAndroid Build Coastguard Worker       self,
838*da0073e9SAndroid Build Coastguard Worker       "numel",
839*da0073e9SAndroid Build Coastguard Worker       py::module::import("torch")
840*da0073e9SAndroid Build Coastguard Worker           .attr("ops")
841*da0073e9SAndroid Build Coastguard Worker           .attr("aten")
842*da0073e9SAndroid Build Coastguard Worker           .attr("numel")
843*da0073e9SAndroid Build Coastguard Worker           .attr("default")
844*da0073e9SAndroid Build Coastguard Worker           .ptr(),
845*da0073e9SAndroid Build Coastguard Worker       "torch.ops.aten");
846*da0073e9SAndroid Build Coastguard Worker 
847*da0073e9SAndroid Build Coastguard Worker   if (out.is_none()) {
848*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
849*da0073e9SAndroid Build Coastguard Worker         !self->has_symbolic_sizes_strides(),
850*da0073e9SAndroid Build Coastguard Worker         "Cannot call sizes on a tensor with symbolic shapes/strides");
851*da0073e9SAndroid Build Coastguard Worker     return self->numel_default();
852*da0073e9SAndroid Build Coastguard Worker   }
853*da0073e9SAndroid Build Coastguard Worker   return py::cast<int64_t>(out);
854*da0073e9SAndroid Build Coastguard Worker }
855*da0073e9SAndroid Build Coastguard Worker 
// Symbolic element count for a tensor whose TensorImpl defers numel queries
// to Python. Dispatches torch.ops.aten.sym_numel.default through
// torchDispatchFromTensorImpl.
//
// None falls back to sym_numel_default(). A result recognized by
// torch::is_symint is cast to c10::SymInt. Anything else is cast to int64_t
// and wrapped in a SymInt.
c10::SymInt ConcretePyInterpreterVTable::sym_numel(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  py::object sym_numel_op = py::module::import("torch")
                                .attr("ops")
                                .attr("aten")
                                .attr("sym_numel")
                                .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "sym_numel", sym_numel_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_numel_default();
  }
  if (torch::is_symint(out)) {
    return out.cast<c10::SymInt>();
  }
  return c10::SymInt{py::cast<int64_t>(out)};
}
877*da0073e9SAndroid Build Coastguard Worker 
// Symbolic storage offset for a tensor whose TensorImpl defers the query to
// Python. Dispatches torch.ops.aten.sym_storage_offset.default through
// torchDispatchFromTensorImpl.
//
// None falls back to sym_storage_offset_default(). A result recognized by
// torch::is_symint is cast to c10::SymInt. Anything else is cast to int64_t
// and wrapped in a SymInt.
c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  py::object offset_op = py::module::import("torch")
                             .attr("ops")
                             .attr("aten")
                             .attr("sym_storage_offset")
                             .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "sym_storage_offset", offset_op.ptr(), "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_storage_offset_default();
  }
  if (torch::is_symint(out)) {
    return out.cast<c10::SymInt>();
  }
  return c10::SymInt{py::cast<int64_t>(out)};
}
899*da0073e9SAndroid Build Coastguard Worker 
// Symbolic strides for a tensor whose TensorImpl defers stride queries to
// Python. Dispatches torch.ops.aten.sym_stride.default through
// torchDispatchFromTensorImpl.
//
// A None result falls back to sym_strides_default(). Otherwise the result
// must be a list or tuple. Its entries (plain ints or SymInt nodes) are
// normalized to c10::SymInt by get_set_cached_attr<c10::SymInt>, keyed by
// "_sym_strides_capsule", since that is the element type sym_strides()
// returns.
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  py::object sym_stride_op = py::module::import("torch")
                                 .attr("ops")
                                 .attr("aten")
                                 .attr("sym_stride")
                                 .attr("default");
  auto out = torchDispatchFromTensorImpl(
      self, "sym_stride", sym_stride_op.ptr(), "torch.ops.aten");

  if (!out.is_none()) {
    TORCH_CHECK(
        py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
        "sym_strides must be a list or a tuple");
    return get_set_cached_attr<c10::SymInt>(
        self, "_sym_strides_capsule", out);
  }
  return self->sym_strides_default();
  END_HANDLE_TH_ERRORS_PYBIND
}
930*da0073e9SAndroid Build Coastguard Worker 
reset_backward_hooks(const c10::TensorImpl * self) const931*da0073e9SAndroid Build Coastguard Worker void ConcretePyInterpreterVTable::reset_backward_hooks(
932*da0073e9SAndroid Build Coastguard Worker     const c10::TensorImpl* self) const {
933*da0073e9SAndroid Build Coastguard Worker   pybind11::gil_scoped_acquire gil;
934*da0073e9SAndroid Build Coastguard Worker   at::impl::MaybeSetTLSOnEntryGuard guard;
935*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
936*da0073e9SAndroid Build Coastguard Worker   Tensor self_t =
937*da0073e9SAndroid Build Coastguard Worker       Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
938*da0073e9SAndroid Build Coastguard Worker                  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
939*da0073e9SAndroid Build Coastguard Worker              unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
940*da0073e9SAndroid Build Coastguard Worker   auto self_p =
941*da0073e9SAndroid Build Coastguard Worker       py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
942*da0073e9SAndroid Build Coastguard Worker   PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
943*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS_PYBIND
944*da0073e9SAndroid Build Coastguard Worker }
945*da0073e9SAndroid Build Coastguard Worker 
name() const946*da0073e9SAndroid Build Coastguard Worker std::string ConcretePyInterpreterVTable::name() const {
947*da0073e9SAndroid Build Coastguard Worker   std::stringstream ss;
948*da0073e9SAndroid Build Coastguard Worker   ss << getPyInterpreter();
949*da0073e9SAndroid Build Coastguard Worker   return ss.str();
950*da0073e9SAndroid Build Coastguard Worker }
951*da0073e9SAndroid Build Coastguard Worker 
952*da0073e9SAndroid Build Coastguard Worker PyInterpreterHolder self_interpreter;
953*da0073e9SAndroid Build Coastguard Worker 
954*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
955*da0073e9SAndroid Build Coastguard Worker 
getTorchApiFunction(const c10::OperatorHandle & op)956*da0073e9SAndroid Build Coastguard Worker py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
957*da0073e9SAndroid Build Coastguard Worker   return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
958*da0073e9SAndroid Build Coastguard Worker     // Parse the name into namespace and name (no overload_name)
959*da0073e9SAndroid Build Coastguard Worker     // TODO: put this into the library
960*da0073e9SAndroid Build Coastguard Worker     const auto& schema = op.schema();
961*da0073e9SAndroid Build Coastguard Worker     const auto& qualified_name = op.operator_name().name;
962*da0073e9SAndroid Build Coastguard Worker     const auto& overload_name = schema.overload_name();
963*da0073e9SAndroid Build Coastguard Worker     auto pos = qualified_name.find("::");
964*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
965*da0073e9SAndroid Build Coastguard Worker     // Make me some null terminated strings
966*da0073e9SAndroid Build Coastguard Worker     std::string ns_str = qualified_name.substr(0, pos);
967*da0073e9SAndroid Build Coastguard Worker     const char* ns = ns_str.c_str();
968*da0073e9SAndroid Build Coastguard Worker     const char* func_name = qualified_name.c_str() + pos + strlen("::");
969*da0073e9SAndroid Build Coastguard Worker 
970*da0073e9SAndroid Build Coastguard Worker     py::handle torch_api_function =
971*da0073e9SAndroid Build Coastguard Worker         py::module::import("torch").attr("ops").attr(ns).attr(func_name);
972*da0073e9SAndroid Build Coastguard Worker     if (overload_name.empty()) {
973*da0073e9SAndroid Build Coastguard Worker       return torch_api_function.attr("default").ptr();
974*da0073e9SAndroid Build Coastguard Worker     } else {
975*da0073e9SAndroid Build Coastguard Worker       return torch_api_function.attr(overload_name.c_str()).ptr();
976*da0073e9SAndroid Build Coastguard Worker     }
977*da0073e9SAndroid Build Coastguard Worker   });
978*da0073e9SAndroid Build Coastguard Worker }
979*da0073e9SAndroid Build Coastguard Worker 
980*da0073e9SAndroid Build Coastguard Worker } // namespace torch::detail
981*da0073e9SAndroid Build Coastguard Worker 
getPyInterpreter()982*da0073e9SAndroid Build Coastguard Worker c10::impl::PyInterpreter* getPyInterpreter() {
983*da0073e9SAndroid Build Coastguard Worker   return torch::detail::self_interpreter.get();
984*da0073e9SAndroid Build Coastguard Worker }
985*da0073e9SAndroid Build Coastguard Worker 
isMainPyInterpreter()986*da0073e9SAndroid Build Coastguard Worker bool isMainPyInterpreter() {
987*da0073e9SAndroid Build Coastguard Worker   return torch::detail::self_interpreter.is_main_interpreter();
988*da0073e9SAndroid Build Coastguard Worker }
989