#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <string>

using namespace torch;
using namespace at;
using namespace c10;

namespace torch::detail {
namespace {

// NB: This is a macro and not a template function (like it was before)
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
#define CONCRETE_GPU_TRACE(device_type, func_name, ...)                       \
  at::impl::MaybeSetTLSOnEntryGuard guard;                                    \
  if (Py_IsInitialized()) {                                                   \
    pybind11::gil_scoped_acquire gil;                                         \
    try {                                                                     \
      /* Masquerade hip as cuda because hip uses `torch.cuda` module. */      \
      if (device_type == at::kHIP) {                                          \
        device_type = at::kCUDA;                                              \
      }                                                                       \
      std::string module_name = "torch." + DeviceTypeName(device_type, true); \
      py::module mod = py::module::import(module_name.c_str());               \
      py::object hook =                                                       \
          mod.attr("_gpu_trace").attr(func_name).attr("fire_callbacks");      \
      hook(__VA_ARGS__);                                                      \
    } catch (const std::exception& e) {                                       \
      LOG(ERROR) << device_type                                               \
                 << " trace hook execution failed: " << e.what();             \
    }                                                                         \
  }

struct ConcretePyInterpreterVTable final
    : public c10::impl::PyInterpreterVTable {
  std::string name() const override;

  void incref(PyObject* pyobj) const override;
  void decref(PyObject* pyobj, bool has_pyobj_slot) const override;

  // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
  // operate upon a PyObjectSlot rather than a TensorImpl
  c10::intrusive_ptr<c10::TensorImpl> detach(
      const c10::TensorImpl* self) const override;

  void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const override;
  void reportErrorCallback(PyObject* callback, DispatchKey key) const override;
  void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const override;
  // NB: this is defined in python_dispatch.cpp
  void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey key,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack,
      bool with_keyset,
      bool with_op) const override {
    torch::impl::dispatch::python_op_registration_trampoline_impl(
        op, key, keyset, stack, with_keyset, with_op);
  }
  void throw_abstract_impl_not_imported_error(
      std::string opname,
      const char* pymodule,
      const char* context) const override {
    py::gil_scoped_acquire gil;
    pybind11::module::import("torch._utils_internal")
        .attr("throw_abstract_impl_not_imported_error")(
            opname, pymodule, context);
  }

  bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_non_overlapping_and_dense(const c10::TensorImpl* self)
      const override;
  c10::Device device(const c10::TensorImpl* self) const override;
  int64_t dim(const c10::TensorImpl* self) const override;
  c10::IntArrayRef strides(const c10::TensorImpl* self) const override;
  c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
  c10::Layout layout(const c10::TensorImpl* self) const override;
  int64_t numel(const c10::TensorImpl* self) const override;
  c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
  c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;

  void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
      const override {
    CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
  }

  void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
      const override {
    CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
  }

  void trace_gpu_event_record(
      at::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
  }

  void trace_gpu_event_wait(
      at::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
  }

  void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
      const override {
    CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
  }

  void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
      const override {
    CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
  }

  void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
      const override {
    CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
  }

  void trace_gpu_device_synchronization(
      at::DeviceType device_type) const override {
    CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
  }

  void trace_gpu_stream_synchronization(
      at::DeviceType device_type,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
  }

  void trace_gpu_event_synchronization(
      at::DeviceType device_type,
      uintptr_t event) const override {
    CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
  }

  void reset_backward_hooks(const c10::TensorImpl* self) const override;

  static ConcretePyInterpreterVTable* instance() {
    static ConcretePyInterpreterVTable s;
    return &s;
  }
};

class PyInterpreterHolder {
 public:
  PyInterpreterHolder()
      : impl_(new c10::impl::PyInterpreter(
            ConcretePyInterpreterVTable::instance())),
        is_main_interpreter_(
            at::impl::PythonOpRegistrationTrampoline::registerInterpreter(
                impl_)) {}
  // NB: intentionally leaks the PyInterpreter, as there may still be
  // references to it that are live, living in objects that aren't being
  // destructed while Python is being cleaned up.
  ~PyInterpreterHolder() {
    impl_->disarm();
  }
  c10::impl::PyInterpreter* get() const noexcept {
    return impl_;
  }
  bool is_main_interpreter() const noexcept {
    return is_main_interpreter_;
  }

 private:
  c10::impl::PyInterpreter* impl_;
  bool is_main_interpreter_;
};

py::object torchDispatchFromTensorImpl(
    const c10::TensorImpl* self,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    // WARNING: MUST NOT BE TENSOR ARGS
    c10::SmallVector<py::object, 1> extra_args = {}) {
  if (torch_api_function == nullptr) {
    throw python_error();
  }
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");

  std::vector<PyObject*> overloaded_args;
  // TODO: there should be a shorter way to spell this
  // TODO: fix the constness of target
  at::Tensor self_t = at::Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  // NB: this may not be a python tensor if you got here from a mode!
  // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
  append_overloaded_tensor(&overloaded_args, self_p.ptr());
  auto args = py::reinterpret_steal<py::object>(
      PyTuple_New(static_cast<Py_ssize_t>(1 + extra_args.size())));
  PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
  int64_t i = 1;
  for (auto& a : extra_args) {
    if (a.ptr() == nullptr)
      throw python_error();
    PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
    i++;
  }

  py::dict kwargs;

  return py::reinterpret_steal<py::object>(
      handle_torch_function_no_python_arg_parser(
          overloaded_args,
          args.ptr(),
          kwargs.ptr(),
          func_name,
          torch_api_function,
          module_name,
          TorchFunctionName::TorchDispatch));
}

// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj has a PyObjectSlot or not.
// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
// - If it does not have a PyObjectSlot, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
    const {
  // Leak the pyobj if not initialized. This can happen if we are running
  // exit handlers that are destructing tensors with residual (owned)
  // PyObjects stored in them.
  if (!Py_IsInitialized())
    return;

  pybind11::gil_scoped_acquire gil;
  // Two possibilities:
  // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
  //    Storage. Then we must be careful about PyObject resurrection (see
  //    THPVariable_clear).
  // 2. We are decref-ing some other Python object. We don't do
  //    PyObject resurrection on non-Tensors, so we just carry on as usual
  if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
    if (THPVariable_Check(pyobj)) {
      // It's still alive! This can happen if a weak ref resurrected
      // the PyObject without flipping ownership. At this point it is
      // too late to rescue the object, so just stub out the PyObject
      // so that it fails on subsequent uses. Don't raise an error here;
      // you're probably in a destructor.
      TORCH_WARN(
          "Deallocating Tensor that still has live PyObject references. "
          "This probably happened because you took out a weak reference to "
          "Tensor and didn't call _fix_weakref() after dereferencing it. "
          "Subsequent accesses to this tensor via the PyObject will now fail.");
      ((THPVariable*)pyobj)->cdata =
          c10::MaybeOwned<torch::autograd::Variable>();
    } else if (THPStorage_Check(pyobj)) {
      TORCH_WARN(
          "Deallocating UntypedStorage that still has live PyObject references. "
          "This probably happened because you took out a weak reference to "
          "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
          "Subsequent accesses to this storage via the PyObject will now fail.");
      ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
    }
  }
  Py_DECREF(pyobj);
};

void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
  if (!Py_IsInitialized())
    return;
  pybind11::gil_scoped_acquire gil;
  Py_INCREF(pyobj);
};

bool isPythonTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}

void ConcretePyInterpreterVTable::reportErrorCallback(
    PyObject* callback,
    DispatchKey key) const {
  py::gil_scoped_acquire g;
  auto func = py::reinterpret_borrow<py::object>(callback);
  // Not all DispatchKeys are pybind'ed into Python and we do not have infra
  // to ensure this, so just pass a string back to Python.
  func(c10::toString(key));
}

void ConcretePyInterpreterVTable::dispatch(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) const {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  // The plan: convert all the arguments back into PyObjects,
  // extracting out the tensor handles, then call
  // handle_torch_function_no_python_arg_parser
  // NB: at the point arguments are pushed to the stack, ALL defaults
  // are already present
  py::gil_scoped_acquire g;

  std::vector<PyObject*> overloaded_args;
  py::handle torch_api_function_overload = getTorchApiFunction(op);

  // Find overloaded tensors
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      const auto& tensor = ivalue.toTensor();
      if (isPythonTensor(tensor)) {
        append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
      }
    } else if (ivalue.isList()) {
      const auto& list = ivalue.toListRef();
      for (const auto jdx : c10::irange(list.size())) {
        const auto& nv = list[jdx];
        if (nv.isTensor()) {
          const auto& tensor = nv.toTensor();
          if (isPythonTensor(tensor)) {
            append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
          }
        }
      }
    }
  }

  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  auto args = std::move(args_kwargs.first);
  auto kwargs = std::move(args_kwargs.second);

  PyObject* obj = handle_torch_function_no_python_arg_parser(
      overloaded_args,
      args.ptr(),
      kwargs.ptr(),
      nullptr,
      torch_api_function_overload.ptr(),
      nullptr,
      TorchFunctionName::TorchDispatch);
  pushPyOutToStack(
      op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
}

void ConcretePyInterpreterVTable::python_dispatcher(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet ks,
    torch::jit::Stack* stack) const {
  py::gil_scoped_acquire g;
  py::handle torch_api_function_overload = getTorchApiFunction(op);
  // TODO: if necessary, can optimize to cache the cache lookup
  // TODO: if necessary, can optimize OpOverload to have slots
  auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache"));
  if (cache.ptr() == nullptr) {
    throw python_error();
  }

  c10::DispatchKey k = ks.highestPriorityTypeId();
  // TODO: allow this to be non-owning
  auto handler = py::reinterpret_borrow<py::object>(
      PyDict_GetItem(cache.ptr(), py::cast(k).ptr()));
  if (handler.ptr() == nullptr) {
    // Slow path
    handler = torch_api_function_overload.attr("_get_dispatch")(k);
  }
  if (py::isinstance<c10::DispatchKey>(handler)) {
    // NB: not redispatch, as that will permanently remove the python
    // dispatcher for subsequent redispatches
    op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack);
    return;
  }

  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  auto args = std::move(args_kwargs.first);
  auto kwargs = std::move(args_kwargs.second);

  py::object obj = py::reinterpret_steal<py::object>(
      PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));

  if (obj.ptr() == nullptr) {
    throw python_error();
  }

  pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher");
}

c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "detach",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("detach")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  TORCH_CHECK(
      THPVariable_Check(out.ptr()),
      "detach returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected Tensor");
  const at::Tensor& res_t = THPVariable_Unpack(out.ptr());
  return res_t.getIntrusivePtr();
}

bool ConcretePyInterpreterVTable::is_contiguous(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  py::object out;
  if (memory_format == at::MemoryFormat::Contiguous) {
    // For backwards compatibility
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        py::module::import("torch")
            .attr("ops")
            .attr("aten")
            .attr("is_contiguous")
            .attr("default")
            .ptr(),
        "torch.ops.aten");
  } else {
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        py::module::import("torch")
            .attr("ops")
            .attr("aten")
            .attr("is_contiguous")
            .attr("memory_format")
            .ptr(),
        "torch.ops.aten",
        {py::cast(memory_format)});
  }

  if (out.is_none()) {
    return self->is_contiguous_default(memory_format);
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_contiguous returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

bool ConcretePyInterpreterVTable::is_strides_like(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "is_strides_like",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          // NB: intentionally suffixed with _format to avoid
          // triggering matches against "_like" suffix
          .attr("is_strides_like_format")
          .attr("default")
          .ptr(),
      "torch.ops.aten",
      {py::cast(memory_format)});

  if (out.is_none()) {
    return self->is_strides_like_default(memory_format);
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_strides_like_format returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "is_non_overlapping_and_dense",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("is_non_overlapping_and_dense")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->is_non_overlapping_and_dense_default();
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_non_overlapping_and_dense returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "dim",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("dim")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  TORCH_CHECK(
      PyLong_Check(out.ptr()),
      "dim returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected int");

  return THPUtils_unpackLong(out.ptr());
}

c10::Device ConcretePyInterpreterVTable::device(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "device",
      py::module::import("torch")
          .attr("ops")
          .attr("prim")
          .attr("device")
          .attr("default")
          .ptr(),
      "torch.ops.prim");

  return toDevice(out.ptr());
}

static void set_tensor_attr_with_capsule(
    const c10::TensorImpl* tensor,
    py::capsule& capsule,
    const char* attr_name) {
  std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  auto obj = mb_obj.value();
  py::handle(obj).attr(attr_name) = capsule;
}

// Note [Tensor Subclass custom size/stride caching strategy]
// Tensor subclasses can use __torch_dispatch__ to override size/stride calls.
// However, this presents a problem:
// (1) When you return a custom (maybe symbolic) size/stride
//     from python, we need to stash this fresh vector of ints/symints
//     somewhere so that it has the same lifetime as the tensor.
// (2) If the subclass experiences a metadata mutation,
//     this stashed vector is no longer valid, so we need to allocate a fresh
//     buffer to store the new sizes the next time someone asks for them.
//
// We handle this in the same way that `TensorImpl::sizes_default()`
// handles its buffer: we simply reallocate the buffer whenever
// the number of dimensions changes due to a resize.
// Notably, we do *not* reallocate the buffer if the values changed,
// but the number of dimensions stayed the same (e.g. `.transpose_()`).
template <typename T>
static c10::ArrayRef<T> get_set_cached_attr(
    const c10::TensorImpl* tensor,
    const char* base_attr_name,
    const py::object& obj) {
  std::optional<PyObject*> mb_obj =
      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  auto tensor_obj = mb_obj.value();
  auto buffer_len_attr_name =
      std::string(base_attr_name) + std::string("_len");

  bool is_buffer_allocated = false;
  size_t curr_size = 0;
  if (PyObject_HasAttrString(tensor_obj, buffer_len_attr_name.c_str())) {
    auto len_pyobj = py::handle(tensor_obj).attr(buffer_len_attr_name.c_str());
    curr_size = py::cast<size_t>(len_pyobj);
    is_buffer_allocated = true;
  }

  size_t new_size = py::len(obj);

  // We do the smallvector optimization here: any time the new_size is <= 5,
  // we always allocate our buffer to size 5, so that if the next resize
  // is also to <= 5 elements, we don't need to reallocate.
  // Note: I tried removing this optimization and tripped ASAN
  // in a batchnorm kernel here:
  // https://pipelinesghubeus21.actions.githubusercontent.com/mBh68xKhi8LyM7tp3vECvYXNFvuV4gyVGgmYCteuEZP9JH92QN/_apis/pipelines/1/runs/3373307/signedlogcontent/790?urlExpires=2023-09-15T21%3A13%3A51.4327798Z&urlSigningMethod=HMACV1&urlSignature=tDeX7ZqaARVU5NNwyr5yYqqkWq3A2j4z8FFdqYwGr0Q%3D
  // We should fix this instead.
  bool needs_resize = false;
  // We need to resize if:
  // (1) we haven't allocated our buffer at all yet
  // (2) Our buffer size is different from the new size
  //     (note: we use the small vector optimization, where our buffer
  //     is always allocated to at least size 5, and any resizes
  //     within the <= 5 regime do not require a reallocation).
  auto is_smallvector = curr_size <= 5;
  needs_resize = !is_buffer_allocated || (is_smallvector && new_size > 5) ||
      (!is_smallvector && curr_size != new_size);
  if (needs_resize) {
    // If our current buffer is not the right size (either because we haven't
    // allocated it yet, or there was a metadata mutation that changed the
    // number of dims of the tensor), allocate a fresh buffer.
    // Note that this will trash the previous buffer if there already was one,
    // invalidating any existing SymIntArrayRef's from an old .sym_size()
    // call.
    auto new_buffer_size = new_size;
    if (new_size <= 5) {
      // This is the smallvector optimization
      new_buffer_size = 5;
    }
    T* ptr = new T[new_buffer_size];
    auto capsule =
        py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<T*>(p); });
    int64_t idx = 0;
    for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
      ptr[idx] = py::cast<T>(*it);
    }
    // Set the buffer
    set_tensor_attr_with_capsule(tensor, capsule, base_attr_name);
    // Set the len buffer
    py::handle(tensor_obj).attr(buffer_len_attr_name.c_str()) = new_size;
  } else {
    TORCH_INTERNAL_ASSERT(PyObject_HasAttrString(tensor_obj, base_attr_name));
    auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
    void* buffer_pycapsule =
        PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
    auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);

    // Overwrite the buffer with our new values, but only if any of them
    // changed (due to a metadata mutation).
    // This is technically not thread safe, because the update happens lazily.
    // The original metadata mutation call on the tensor might have been
    // thread safe (e.g. a .resize_() call), but we won't actually mutate the
    // size buffer until the first call to .sizes() which the user might not
    // access in a thread-safe way. For now we are not explicitly locking, but
    // maybe we should.
    int64_t idx = 0;
    // Quick sanity assert that our buffer size is large enough
    // to compare against all the elements in the new buffer.
    size_t curr_buffer_size = 5;
    if (curr_buffer_size < curr_size) {
      curr_buffer_size = curr_size;
    }
    TORCH_INTERNAL_ASSERT(curr_buffer_size >= new_size);
    for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
      auto actual_val = py::cast<T>(*it);
      if constexpr (std::is_same_v<T, c10::SymInt>) {
        // if our SymInts are symbolic, we are *not* doing an equality check on
        // the symints. we just want to see if the nodes are the same. this is
        // because we don't want to introduce any guards here.
        if (!curr_buffer[idx].is_same(actual_val)) {
          curr_buffer[idx] = actual_val;
        }
      } else {
        if (curr_buffer[idx] != actual_val) {
          curr_buffer[idx] = actual_val;
        }
      }
    }
  }

  // The correct data is now stored at the buffer - read and return it.
  auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
  void* buffer_pycapsule =
      PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
  auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
  return c10::ArrayRef<T>(curr_buffer, new_size);
}

c10::IntArrayRef ConcretePyInterpreterVTable::strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "stride",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("stride")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call strides on a tensor with symbolic shapes/strides");
    return self->strides_default();
  }
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "strides must be a list or a tuple");
  auto updated_strides =
      get_set_cached_attr<int64_t>(self, "_strides_capsule", out);
  return updated_strides;
}

c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "size",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("size")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call sizes on a tensor with symbolic shapes/strides");
    return self->sizes_default();
  }
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sizes must be a list or a tuple");
  auto updated_sizes =
      get_set_cached_attr<int64_t>(self, "_sizes_capsule", out);
  return updated_sizes;
  END_HANDLE_TH_ERRORS_PYBIND
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_size",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_size")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_sizes_default();
  }
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sym_size must be a list or a tuple");
  // See Note [Tensor Subclass custom size/stride caching strategy]
  auto updated_sym_sizes =
      get_set_cached_attr<c10::SymInt>(self, "_sym_sizes_capsule", out);
  return updated_sym_sizes;
  END_HANDLE_TH_ERRORS_PYBIND
}

c10::Layout ConcretePyInterpreterVTable::layout(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "layout",
      py::module::import("torch")
          .attr("ops")
          .attr("prim")
          .attr("layout")
          .attr("default")
          .ptr(),
      "torch.ops.prim");

  TORCH_CHECK(
      THPLayout_Check(out.ptr()) || PyLong_Check(out.ptr()),
      "layout returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected Layout");

  if (THPLayout_Check(out.ptr())) {
    return toLayout(out.ptr());
  } else {
    return c10::Layout(py::cast<int64_t>(out));
  }
}

int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "numel",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("numel")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call numel on a tensor with symbolic shapes/strides");
    return self->numel_default();
  }
  return py::cast<int64_t>(out);
}

c10::SymInt ConcretePyInterpreterVTable::sym_numel(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_numel",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_numel")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_numel_default();
  }
  return torch::is_symint(out) ? out.cast<c10::SymInt>()
                               : c10::SymInt{py::cast<int64_t>(out)};
}

c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_storage_offset",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_storage_offset")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_storage_offset_default();
  }
  return torch::is_symint(out) ? out.cast<c10::SymInt>()
                               : c10::SymInt{py::cast<int64_t>(out)};
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_stride",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_stride")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_strides_default();
  }
  // We need to squeeze SymIntNodes and ints into `SymInts`
  // since it's a format `sym_strides()` are stored in
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sym_strides must be a list or a tuple");

  auto updated_sym_strides =
      get_set_cached_attr<c10::SymInt>(self, "_sym_strides_capsule", out);
  return updated_sym_strides;
  END_HANDLE_TH_ERRORS_PYBIND
}

void ConcretePyInterpreterVTable::reset_backward_hooks(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  Tensor self_t =
      Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
                 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
             unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
  END_HANDLE_TH_ERRORS_PYBIND
}

std::string ConcretePyInterpreterVTable::name() const {
  std::stringstream ss;
  ss << getPyInterpreter();
  return ss.str();
}

PyInterpreterHolder self_interpreter;

} // anonymous namespace

py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
    // Parse the name into namespace and name (no overload_name)
    // TODO: put this into the library
    const auto& schema = op.schema();
    const auto& qualified_name = op.operator_name().name;
    const auto& overload_name = schema.overload_name();
    auto pos = qualified_name.find("::");
    TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
    // Make me some null terminated strings
    std::string ns_str = qualified_name.substr(0, pos);
    const char* ns = ns_str.c_str();
    const char* func_name = qualified_name.c_str() + pos + strlen("::");
    py::handle torch_api_function =
        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
    if (overload_name.empty()) {
      return torch_api_function.attr("default").ptr();
    } else {
      return torch_api_function.attr(overload_name.c_str()).ptr();
    }
  });
}

} // namespace torch::detail

c10::impl::PyInterpreter* getPyInterpreter() {
  return torch::detail::self_interpreter.get();
}

bool isMainPyInterpreter() {
  return torch::detail::self_interpreter.is_main_interpreter();
}
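
// Illustrative sketch (comment only, not compiled): roughly how the metadata
// hooks in this file are exercised from Python. When a __torch_dispatch__
// subclass is set up so that size/stride queries are dispatched to Python,
// a C++ call such as impl->sizes() routes through
// ConcretePyInterpreterVTable::sizes, which invokes torch.ops.aten.size.default
// via torchDispatchFromTensorImpl and caches the returned sequence on the
// Python object under "_sizes_capsule" / "_sizes_capsule_len" (see
// get_set_cached_attr and
// Note [Tensor Subclass custom size/stride caching strategy]).
// The class and the get_inner_size helper below are hypothetical and elide
// the subclass construction details:
//
//   class MySubclass(torch.Tensor):
//       @classmethod
//       def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
//           kwargs = kwargs or {}
//           if func is torch.ops.aten.size.default:
//               # The list returned here is what get_set_cached_attr<int64_t>
//               # stashes on this object via the "_sizes_capsule" attributes.
//               return list(cls.get_inner_size(args[0]))  # hypothetical helper
//           return func(*args, **kwargs)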