xref: /aosp_15_r20/external/pytorch/torch/csrc/Module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <ATen/DeviceAccelerator.h>
2*da0073e9SAndroid Build Coastguard Worker #include <fmt/core.h>
3*da0073e9SAndroid Build Coastguard Worker #include <sys/types.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/python_headers.h>
5*da0073e9SAndroid Build Coastguard Worker #include <optional>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #ifndef _MSC_VER
8*da0073e9SAndroid Build Coastguard Worker #include <sys/socket.h>
9*da0073e9SAndroid Build Coastguard Worker #endif
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker #include <ATen/ATen.h>
12*da0073e9SAndroid Build Coastguard Worker #include <ATen/BlasBackend.h>
13*da0073e9SAndroid Build Coastguard Worker #include <ATen/CachedTensorUtils.h>
14*da0073e9SAndroid Build Coastguard Worker #include <ATen/DLConvertor.h>
15*da0073e9SAndroid Build Coastguard Worker #include <ATen/ExpandUtils.h>
16*da0073e9SAndroid Build Coastguard Worker #include <ATen/LegacyVmapMode.h>
17*da0073e9SAndroid Build Coastguard Worker #include <ATen/LinalgBackend.h>
18*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel.h>
19*da0073e9SAndroid Build Coastguard Worker #include <ATen/Utils.h>
20*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/Vitals.h>
21*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/AcceleratorHooksInterface.h>
22*da0073e9SAndroid Build Coastguard Worker #include <ATen/dlpack.h>
23*da0073e9SAndroid Build Coastguard Worker #include <ATen/native/ConvUtils.h>
24*da0073e9SAndroid Build Coastguard Worker #include <ATen/native/ForeachUtils.h>
25*da0073e9SAndroid Build Coastguard Worker #include <ATen/native/Normalization.h>
26*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Device.h>
27*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DispatchKeySet.h>
28*da0073e9SAndroid Build Coastguard Worker #include <c10/util/AbortHandler.h>
29*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Backtrace.h>
30*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Logging.h>
31*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
32*da0073e9SAndroid Build Coastguard Worker #include <c10/util/thread_name.h>
33*da0073e9SAndroid Build Coastguard Worker #include <libshm.h>
34*da0073e9SAndroid Build Coastguard Worker #include <pybind11/pybind11.h>
35*da0073e9SAndroid Build Coastguard Worker #include <pybind11/stl.h>
36*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/THConcat.h>
37*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/pybind.h>
#include <cstdlib>
#include <deque>
#include <iostream>
#include <limits>
#include <unordered_map>
41*da0073e9SAndroid Build Coastguard Worker 
42*da0073e9SAndroid Build Coastguard Worker #include <ATen/ThreadLocalPythonObjects.h>
43*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/DataLoader.h>
44*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Device.h>
45*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Dtype.h>
46*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/DynamicTypes.h>
47*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Event.h>
48*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Generator.h>
49*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Layout.h>
50*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/MemoryFormat.h>
51*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/QScheme.h>
52*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Stream.h>
53*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/THP.h>
54*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/TypeInfo.h>
55*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/api/include/torch/python/init.h>
56*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/generated/python_return_types.h>
57*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_cpp_function.h>
58*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_enum_tag.h>
59*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_fft_functions.h>
60*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_function.h>
61*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_legacy_variable.h>
62*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_linalg_functions.h>
63*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_nested_functions.h>
64*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_nn_functions.h>
65*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_sparse_functions.h>
66*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_special_functions.h>
67*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/python_variable.h>
68*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/cpu/Module.h>
69*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/dynamo/init.h>
70*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/functorch/init.h>
71*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/fx/node.h>
72*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/inductor/aoti_runner/pybind.h>
73*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/instruction_counter/Module.h>
74*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/python/init.h>
75*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/python/python_ir.h>
76*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/python/python_tracer.h>
77*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/serialization/pickler.h>
78*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/lazy/python/init.h>
79*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/monitor/python_init.h>
80*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/mps/Module.h>
81*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/mtia/Module.h>
82*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/multiprocessing/init.h>
83*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/onnx/init.h>
84*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/profiler/python/init.h>
85*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/tensor/python_tensor.h>
86*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/disable_torch_function.h>
87*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/init.h>
88*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/pycfunction_helpers.h>
89*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/python_arg_parser.h>
90*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/python_compat.h>
91*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/python_dispatch.h>
92*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/python_strings.h>
93*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/tensor_dtypes.h>
94*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/tensor_layouts.h>
95*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/tensor_memoryformats.h>
96*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/tensor_new.h>
97*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/tensor_numpy.h>
98*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/tensor_qschemes.h>
99*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/verbose.h>
100*da0073e9SAndroid Build Coastguard Worker 
101*da0073e9SAndroid Build Coastguard Worker #include <ATen/native/transformers/sdp_utils_cpp.h>
102*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/profiler/combined_traceback.h>
103*da0073e9SAndroid Build Coastguard Worker #include <sstream>
104*da0073e9SAndroid Build Coastguard Worker 
105*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
106*da0073e9SAndroid Build Coastguard Worker #include <ATen/cuda/CUDAConfig.h>
107*da0073e9SAndroid Build Coastguard Worker #include <ATen/native/transformers/cuda/sdp_utils.h>
108*da0073e9SAndroid Build Coastguard Worker #ifdef __HIP_PLATFORM_AMD__
109*da0073e9SAndroid Build Coastguard Worker #include <ATen/native/cudnn/hip/BatchNorm.h>
110*da0073e9SAndroid Build Coastguard Worker #else
111*da0073e9SAndroid Build Coastguard Worker #include <ATen/native/cudnn/BatchNorm.h>
112*da0073e9SAndroid Build Coastguard Worker #endif
113*da0073e9SAndroid Build Coastguard Worker #endif
114*da0073e9SAndroid Build Coastguard Worker 
115*da0073e9SAndroid Build Coastguard Worker #ifdef USE_DISTRIBUTED
116*da0073e9SAndroid Build Coastguard Worker #ifdef USE_C10D
117*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/distributed/autograd/python_autograd.h>
118*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/distributed/c10d/c10d.h>
119*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/distributed/rpc/rpc.h>
120*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/distributed/rpc/testing/testing.h>
121*da0073e9SAndroid Build Coastguard Worker #endif
122*da0073e9SAndroid Build Coastguard Worker #endif
123*da0073e9SAndroid Build Coastguard Worker 
124*da0073e9SAndroid Build Coastguard Worker #if defined(USE_VALGRIND)
125*da0073e9SAndroid Build Coastguard Worker #include <callgrind.h>
126*da0073e9SAndroid Build Coastguard Worker #endif
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker namespace py = pybind11;
129*da0073e9SAndroid Build Coastguard Worker 
130*da0073e9SAndroid Build Coastguard Worker PyObject* module;
131*da0073e9SAndroid Build Coastguard Worker 
132*da0073e9SAndroid Build Coastguard Worker THPGenerator* THPDefaultCPUGenerator = nullptr;
133*da0073e9SAndroid Build Coastguard Worker 
134*da0073e9SAndroid Build Coastguard Worker ////////////////////////////////////////////////////////////////////////////////
135*da0073e9SAndroid Build Coastguard Worker ////////////////////////////////////////////////////////////////////////////////
136*da0073e9SAndroid Build Coastguard Worker 
THPModule_initNames(PyObject * self,PyObject * arg)137*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_initNames(PyObject* self, PyObject* arg) {
138*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
139*da0073e9SAndroid Build Coastguard Worker   static std::vector<std::string> names;
140*da0073e9SAndroid Build Coastguard Worker 
141*da0073e9SAndroid Build Coastguard Worker   THPObjectPtr types(PySequence_Fast(arg, "expected a sequence"));
142*da0073e9SAndroid Build Coastguard Worker   if (!types)
143*da0073e9SAndroid Build Coastguard Worker     return nullptr;
144*da0073e9SAndroid Build Coastguard Worker 
145*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(bugprone-branch-clone)
146*da0073e9SAndroid Build Coastguard Worker   auto num_classes = PySequence_Fast_GET_SIZE(types.get());
147*da0073e9SAndroid Build Coastguard Worker   names.reserve(names.size() + num_classes);
148*da0073e9SAndroid Build Coastguard Worker   for (Py_ssize_t i = 0; i < num_classes; i++) {
149*da0073e9SAndroid Build Coastguard Worker     PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
150*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(PyType_Check(obj), "expected a PyTypeObject");
151*da0073e9SAndroid Build Coastguard Worker     PyTypeObject* type = (PyTypeObject*)obj;
152*da0073e9SAndroid Build Coastguard Worker 
153*da0073e9SAndroid Build Coastguard Worker     THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__"));
154*da0073e9SAndroid Build Coastguard Worker     if (!module_name)
155*da0073e9SAndroid Build Coastguard Worker       return nullptr;
156*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
157*da0073e9SAndroid Build Coastguard Worker         THPUtils_checkString(module_name.get()),
158*da0073e9SAndroid Build Coastguard Worker         "expected __module__ to be a string");
159*da0073e9SAndroid Build Coastguard Worker     std::string name = THPUtils_unpackString(module_name.get());
160*da0073e9SAndroid Build Coastguard Worker     names.emplace_back(name + "." + type->tp_name);
161*da0073e9SAndroid Build Coastguard Worker     type->tp_name = names.back().c_str();
162*da0073e9SAndroid Build Coastguard Worker   }
163*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
164*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
165*da0073e9SAndroid Build Coastguard Worker }
166*da0073e9SAndroid Build Coastguard Worker //
167*da0073e9SAndroid Build Coastguard Worker // Callback for python part. Used for additional initialization of python
168*da0073e9SAndroid Build Coastguard Worker // classes
THPModule_initExtension(PyObject * _unused,PyObject * shm_manager_path)169*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_initExtension(
170*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
171*da0073e9SAndroid Build Coastguard Worker     PyObject* shm_manager_path) {
172*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
173*da0073e9SAndroid Build Coastguard Worker #if !defined(FBCODE_CAFFE2) && !defined(__aarch64__)
174*da0073e9SAndroid Build Coastguard Worker   if (torch::get_cpp_stacktraces_enabled()) {
175*da0073e9SAndroid Build Coastguard Worker     c10::SetStackTraceFetcher([]() -> std::string {
176*da0073e9SAndroid Build Coastguard Worker       auto tb = torch::CapturedTraceback::gather(false, false, true);
177*da0073e9SAndroid Build Coastguard Worker       if (torch::get_symbolize_mode() == torch::unwind::Mode::addr2line) {
178*da0073e9SAndroid Build Coastguard Worker         LOG(WARNING)
179*da0073e9SAndroid Build Coastguard Worker             << "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
180*da0073e9SAndroid Build Coastguard Worker             << std::endl;
181*da0073e9SAndroid Build Coastguard Worker       }
182*da0073e9SAndroid Build Coastguard Worker       auto s_tbs = torch::symbolize({tb.get()});
183*da0073e9SAndroid Build Coastguard Worker       std::stringstream oss;
184*da0073e9SAndroid Build Coastguard Worker       oss << "C++ CapturedTraceback:" << std::endl;
185*da0073e9SAndroid Build Coastguard Worker       const auto& s_tb = s_tbs.tracebacks.at(0);
186*da0073e9SAndroid Build Coastguard Worker       for (auto idx : c10::irange(s_tb.size())) {
187*da0073e9SAndroid Build Coastguard Worker         // Skip the first few frames:
188*da0073e9SAndroid Build Coastguard Worker         //  #1 torch::CapturedTraceback::gather(bool, bool, bool)
189*da0073e9SAndroid Build Coastguard Worker         //  #2 THPModule_initExtension
190*da0073e9SAndroid Build Coastguard Worker         //  #3 THPModule_initExtension(_object*, _object*)::{lambda()#1}
191*da0073e9SAndroid Build Coastguard Worker         if (idx <= 3) {
192*da0073e9SAndroid Build Coastguard Worker           continue;
193*da0073e9SAndroid Build Coastguard Worker         }
194*da0073e9SAndroid Build Coastguard Worker         auto frame_id = s_tb[idx];
195*da0073e9SAndroid Build Coastguard Worker         const auto& frame = s_tbs.all_frames.at(frame_id);
196*da0073e9SAndroid Build Coastguard Worker         oss << "#" << idx << " " << frame.funcname << " from " << frame.filename
197*da0073e9SAndroid Build Coastguard Worker             << ":" << frame.lineno << std::endl;
198*da0073e9SAndroid Build Coastguard Worker       }
199*da0073e9SAndroid Build Coastguard Worker       return oss.str();
200*da0073e9SAndroid Build Coastguard Worker     });
201*da0073e9SAndroid Build Coastguard Worker   }
202*da0073e9SAndroid Build Coastguard Worker #endif
203*da0073e9SAndroid Build Coastguard Worker   if (!THPUtils_checkString(shm_manager_path)) {
204*da0073e9SAndroid Build Coastguard Worker     THPUtils_setError(
205*da0073e9SAndroid Build Coastguard Worker         "initialization error - expected bytes/string object as shm_manager_path!");
206*da0073e9SAndroid Build Coastguard Worker     return nullptr;
207*da0073e9SAndroid Build Coastguard Worker   }
208*da0073e9SAndroid Build Coastguard Worker   torch::utils::initializeLayouts();
209*da0073e9SAndroid Build Coastguard Worker   torch::utils::initializeMemoryFormats();
210*da0073e9SAndroid Build Coastguard Worker   torch::utils::initializeQSchemes();
211*da0073e9SAndroid Build Coastguard Worker   torch::utils::initializeDtypes();
212*da0073e9SAndroid Build Coastguard Worker   torch::tensors::initialize_python_bindings();
213*da0073e9SAndroid Build Coastguard Worker   std::string path = THPUtils_unpackString(shm_manager_path);
214*da0073e9SAndroid Build Coastguard Worker   libshm_init(path.c_str());
215*da0073e9SAndroid Build Coastguard Worker 
216*da0073e9SAndroid Build Coastguard Worker   auto module = THPObjectPtr(PyImport_ImportModule("torch"));
217*da0073e9SAndroid Build Coastguard Worker   if (!module)
218*da0073e9SAndroid Build Coastguard Worker     throw python_error();
219*da0073e9SAndroid Build Coastguard Worker 
220*da0073e9SAndroid Build Coastguard Worker   THPStorage_postInit(module);
221*da0073e9SAndroid Build Coastguard Worker   THPAutograd_initFunctions();
222*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
223*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
224*da0073e9SAndroid Build Coastguard Worker }
225*da0073e9SAndroid Build Coastguard Worker 
226*da0073e9SAndroid Build Coastguard Worker // The idea behind these two functions is to make it easy to test if we are
227*da0073e9SAndroid Build Coastguard Worker // built with ASAN: they're designed not to crash if ASAN is not enabled, but
228*da0073e9SAndroid Build Coastguard Worker // to trigger ASAN if it is enabled.  This lets us run a "canary" tests which
229*da0073e9SAndroid Build Coastguard Worker // checks if our build environment is misconfigured.
230*da0073e9SAndroid Build Coastguard Worker 
THPModule_crashIfCsrcASAN(PyObject * module,PyObject * arg)231*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_crashIfCsrcASAN(PyObject* module, PyObject* arg) {
232*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
233*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
234*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkLong(arg),
235*da0073e9SAndroid Build Coastguard Worker       "crash_if_csrc_asan expects an int, but got ",
236*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
237*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
238*da0073e9SAndroid Build Coastguard Worker   volatile char x[3];
239*da0073e9SAndroid Build Coastguard Worker   x[THPUtils_unpackInt(arg)] = 0;
240*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
241*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt32(x[0]);
242*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
243*da0073e9SAndroid Build Coastguard Worker }
244*da0073e9SAndroid Build Coastguard Worker 
THPModule_crashIfCsrcUBSAN(PyObject * module,PyObject * arg)245*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) {
246*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
247*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
248*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkLong(arg),
249*da0073e9SAndroid Build Coastguard Worker       "crash_if_csrc_ubsan expects an int, but got ",
250*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
251*da0073e9SAndroid Build Coastguard Worker   int32_t x = THPUtils_unpackInt(arg);
252*da0073e9SAndroid Build Coastguard Worker   double y = 1.0 / x;
253*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt32((int)y);
254*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
255*da0073e9SAndroid Build Coastguard Worker }
256*da0073e9SAndroid Build Coastguard Worker 
THPModule_crashIfvptrUBSAN(PyObject * module,PyObject * noarg)257*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) {
258*da0073e9SAndroid Build Coastguard Worker   // This code should work perfectly fine, as vtables are identical for Foo and
259*da0073e9SAndroid Build Coastguard Worker   // Baz unless rtti and ubsan are enabled
260*da0073e9SAndroid Build Coastguard Worker   struct Foo {
261*da0073e9SAndroid Build Coastguard Worker     virtual int bar() = 0;
262*da0073e9SAndroid Build Coastguard Worker     virtual ~Foo() = default;
263*da0073e9SAndroid Build Coastguard Worker   };
264*da0073e9SAndroid Build Coastguard Worker   struct Baz {
265*da0073e9SAndroid Build Coastguard Worker     virtual int bar() {
266*da0073e9SAndroid Build Coastguard Worker       return 17;
267*da0073e9SAndroid Build Coastguard Worker     }
268*da0073e9SAndroid Build Coastguard Worker     virtual ~Baz() = default;
269*da0073e9SAndroid Build Coastguard Worker   };
270*da0073e9SAndroid Build Coastguard Worker   Baz x{};
271*da0073e9SAndroid Build Coastguard Worker   auto y = static_cast<Foo*>(static_cast<void*>(&x));
272*da0073e9SAndroid Build Coastguard Worker   auto rc = y->bar();
273*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt32(rc);
274*da0073e9SAndroid Build Coastguard Worker }
275*da0073e9SAndroid Build Coastguard Worker 
THPModule_crashIfATenASAN(PyObject * module,PyObject * arg)276*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_crashIfATenASAN(PyObject* module, PyObject* arg) {
277*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
278*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
279*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkLong(arg),
280*da0073e9SAndroid Build Coastguard Worker       "crash_if_aten_asan expects an int, "
281*da0073e9SAndroid Build Coastguard Worker       "but got ",
282*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
283*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt32(at::_crash_if_asan(THPUtils_unpackInt(arg)));
284*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
285*da0073e9SAndroid Build Coastguard Worker }
286*da0073e9SAndroid Build Coastguard Worker 
THPModule_abort(PyObject * module,PyObject * noargs)287*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_abort(PyObject* module, PyObject* noargs) {
288*da0073e9SAndroid Build Coastguard Worker   std::terminate();
289*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
290*da0073e9SAndroid Build Coastguard Worker }
291*da0073e9SAndroid Build Coastguard Worker 
THPModule_crashIfDebugAssertsFail(PyObject * module,PyObject * arg)292*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_crashIfDebugAssertsFail(
293*da0073e9SAndroid Build Coastguard Worker     PyObject* module,
294*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
295*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
296*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
297*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkLong(arg),
298*da0073e9SAndroid Build Coastguard Worker       "crash_if_debug_asserts_fail expects an int, but got ",
299*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
300*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
301*da0073e9SAndroid Build Coastguard Worker       THPUtils_unpackInt(arg) != 424242,
302*da0073e9SAndroid Build Coastguard Worker       "Expect anything but 424242 as an input for debug builds");
303*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt32(0);
304*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
305*da0073e9SAndroid Build Coastguard Worker }
306*da0073e9SAndroid Build Coastguard Worker 
THPModule_getNumThreads(PyObject * module,PyObject * noargs)307*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_getNumThreads(PyObject* module, PyObject* noargs) {
308*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt32(at::get_num_threads());
309*da0073e9SAndroid Build Coastguard Worker }
310*da0073e9SAndroid Build Coastguard Worker 
THPModule_setNumThreads(PyObject * module,PyObject * arg)311*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) {
312*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
313*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
314*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkLong(arg),
315*da0073e9SAndroid Build Coastguard Worker       "set_num_threads expects an int, but got ",
316*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
317*da0073e9SAndroid Build Coastguard Worker   int nthreads = (int)THPUtils_unpackLong(arg);
318*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(nthreads > 0, "set_num_threads expects a positive integer");
319*da0073e9SAndroid Build Coastguard Worker   at::set_num_threads(nthreads);
320*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
321*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
322*da0073e9SAndroid Build Coastguard Worker }
323*da0073e9SAndroid Build Coastguard Worker 
THPModule_getNumInteropThreads(PyObject * module,PyObject * noargs)324*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_getNumInteropThreads(
325*da0073e9SAndroid Build Coastguard Worker     PyObject* module,
326*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
327*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt32(at::get_num_interop_threads());
328*da0073e9SAndroid Build Coastguard Worker }
329*da0073e9SAndroid Build Coastguard Worker 
THPModule_setNumInteropThreads(PyObject * module,PyObject * arg)330*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_setNumInteropThreads(
331*da0073e9SAndroid Build Coastguard Worker     PyObject* module,
332*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
333*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
334*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
335*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkLong(arg),
336*da0073e9SAndroid Build Coastguard Worker       "set_num_interop_threads expects an int, "
337*da0073e9SAndroid Build Coastguard Worker       "but got ",
338*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
339*da0073e9SAndroid Build Coastguard Worker   int nthreads = (int)THPUtils_unpackLong(arg);
340*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
341*da0073e9SAndroid Build Coastguard Worker       nthreads > 0, "set_num_interop_threads expects a positive integer");
342*da0073e9SAndroid Build Coastguard Worker   at::set_num_interop_threads(nthreads);
343*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
344*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
345*da0073e9SAndroid Build Coastguard Worker }
346*da0073e9SAndroid Build Coastguard Worker 
THPModule_setDefaultTensorType(PyObject * _unused,PyObject * type)347*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setDefaultTensorType(PyObject* _unused, PyObject* type) {
348*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
349*da0073e9SAndroid Build Coastguard Worker   torch::tensors::py_set_default_tensor_type(type);
350*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
351*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
352*da0073e9SAndroid Build Coastguard Worker }
353*da0073e9SAndroid Build Coastguard Worker 
THPModule_setDefaultDtype(PyObject * _unused,PyObject * dtype)354*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setDefaultDtype(PyObject* _unused, PyObject* dtype) {
355*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
356*da0073e9SAndroid Build Coastguard Worker   torch::tensors::py_set_default_dtype(dtype);
357*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
358*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
359*da0073e9SAndroid Build Coastguard Worker }
360*da0073e9SAndroid Build Coastguard Worker 
THPModule_swap_tensor_impl(PyObject * _unused,PyObject * args)361*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
362*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
363*da0073e9SAndroid Build Coastguard Worker   PyObject* a_ = nullptr;
364*da0073e9SAndroid Build Coastguard Worker   PyObject* b_ = nullptr;
365*da0073e9SAndroid Build Coastguard Worker   if (!PyArg_ParseTuple(args, "OO", &a_, &b_)) {
366*da0073e9SAndroid Build Coastguard Worker     return nullptr;
367*da0073e9SAndroid Build Coastguard Worker   }
368*da0073e9SAndroid Build Coastguard Worker 
369*da0073e9SAndroid Build Coastguard Worker   // Ensure we have Tensors
370*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(THPVariable_Check(a_));
371*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(THPVariable_Check(b_));
372*da0073e9SAndroid Build Coastguard Worker 
373*da0073e9SAndroid Build Coastguard Worker   THPVariable* a = reinterpret_cast<THPVariable*>(a_);
374*da0073e9SAndroid Build Coastguard Worker   THPVariable* b = reinterpret_cast<THPVariable*>(b_);
375*da0073e9SAndroid Build Coastguard Worker 
376*da0073e9SAndroid Build Coastguard Worker   // weak_use_count() adds 1 if use_count is non-zero
377*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
378*da0073e9SAndroid Build Coastguard Worker       a->cdata->weak_use_count() == 1,
379*da0073e9SAndroid Build Coastguard Worker       "Expected no weakrefs to t1's Tensor object but got  ",
380*da0073e9SAndroid Build Coastguard Worker       a->cdata->weak_use_count() - 1);
381*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
382*da0073e9SAndroid Build Coastguard Worker       b->cdata->weak_use_count() == 1,
383*da0073e9SAndroid Build Coastguard Worker       "Expected no weakrefs to t2's Tensor object but got  ",
384*da0073e9SAndroid Build Coastguard Worker       b->cdata->weak_use_count() - 1);
385*da0073e9SAndroid Build Coastguard Worker 
386*da0073e9SAndroid Build Coastguard Worker   // Swap the Tensor Impl
387*da0073e9SAndroid Build Coastguard Worker   c10::MaybeOwned<at::Tensor> tmp = a->cdata;
388*da0073e9SAndroid Build Coastguard Worker 
389*da0073e9SAndroid Build Coastguard Worker   // The TensorImpls contain PyObjectSlots that have a reference to the PyObject
390*da0073e9SAndroid Build Coastguard Worker   // associated with the TensorImpl. Swap this field as well.
391*da0073e9SAndroid Build Coastguard Worker   std::optional<PyObject*> mb_obj_a =
392*da0073e9SAndroid Build Coastguard Worker       a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
393*da0073e9SAndroid Build Coastguard Worker           getPyInterpreter(), /*ignore_hermetic_tls=*/false);
394*da0073e9SAndroid Build Coastguard Worker   std::optional<PyObject*> mb_obj_b =
395*da0073e9SAndroid Build Coastguard Worker       b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
396*da0073e9SAndroid Build Coastguard Worker           getPyInterpreter(), /*ignore_hermetic_tls=*/false);
397*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(
398*da0073e9SAndroid Build Coastguard Worker       mb_obj_a.has_value() && mb_obj_b.has_value(),
399*da0073e9SAndroid Build Coastguard Worker       "Both tensors should have PyObjects tagged by the current python interpreter");
400*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(mb_obj_a.value() == a_);
401*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(mb_obj_b.value() == b_);
402*da0073e9SAndroid Build Coastguard Worker 
403*da0073e9SAndroid Build Coastguard Worker   a->cdata = b->cdata;
404*da0073e9SAndroid Build Coastguard Worker   b->cdata = tmp;
405*da0073e9SAndroid Build Coastguard Worker 
406*da0073e9SAndroid Build Coastguard Worker   a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
407*da0073e9SAndroid Build Coastguard Worker       getPyInterpreter(), a_, c10::impl::PyInterpreterStatus::TAGGED_BY_US);
408*da0073e9SAndroid Build Coastguard Worker   b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
409*da0073e9SAndroid Build Coastguard Worker       getPyInterpreter(), b_, c10::impl::PyInterpreterStatus::TAGGED_BY_US);
410*da0073e9SAndroid Build Coastguard Worker 
411*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
412*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
413*da0073e9SAndroid Build Coastguard Worker }
414*da0073e9SAndroid Build Coastguard Worker 
THPModule_addDocStr(PyObject * _unused,PyObject * args)415*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
416*da0073e9SAndroid Build Coastguard Worker   // adds a __doc__ string to a function, similar to numpy's arr_add_docstring
417*da0073e9SAndroid Build Coastguard Worker   static std::vector<std::string> all_docs;
418*da0073e9SAndroid Build Coastguard Worker   PyObject* obj = nullptr;
419*da0073e9SAndroid Build Coastguard Worker   PyObject* doc_obj = nullptr;
420*da0073e9SAndroid Build Coastguard Worker   if (!PyArg_ParseTuple(args, "OO", &obj, &doc_obj)) {
421*da0073e9SAndroid Build Coastguard Worker     return nullptr;
422*da0073e9SAndroid Build Coastguard Worker   }
423*da0073e9SAndroid Build Coastguard Worker 
424*da0073e9SAndroid Build Coastguard Worker   const char* doc_str = "<invalid string>";
425*da0073e9SAndroid Build Coastguard Worker   if (THPUtils_checkString(doc_obj)) {
426*da0073e9SAndroid Build Coastguard Worker     all_docs.push_back(THPUtils_unpackString(doc_obj));
427*da0073e9SAndroid Build Coastguard Worker     doc_str = all_docs.back().c_str();
428*da0073e9SAndroid Build Coastguard Worker   }
429*da0073e9SAndroid Build Coastguard Worker 
430*da0073e9SAndroid Build Coastguard Worker   if (Py_TYPE(obj) == &PyCFunction_Type) {
431*da0073e9SAndroid Build Coastguard Worker     PyCFunctionObject* f = (PyCFunctionObject*)obj;
432*da0073e9SAndroid Build Coastguard Worker     if (f->m_ml->ml_doc) {
433*da0073e9SAndroid Build Coastguard Worker       return PyErr_Format(
434*da0073e9SAndroid Build Coastguard Worker           PyExc_RuntimeError,
435*da0073e9SAndroid Build Coastguard Worker           "function '%s' already has a docstring",
436*da0073e9SAndroid Build Coastguard Worker           f->m_ml->ml_name);
437*da0073e9SAndroid Build Coastguard Worker     }
438*da0073e9SAndroid Build Coastguard Worker     f->m_ml->ml_doc = doc_str;
439*da0073e9SAndroid Build Coastguard Worker   } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
440*da0073e9SAndroid Build Coastguard Worker     PyMethodDescrObject* m = (PyMethodDescrObject*)obj;
441*da0073e9SAndroid Build Coastguard Worker     if (m->d_method->ml_doc) {
442*da0073e9SAndroid Build Coastguard Worker       return PyErr_Format(
443*da0073e9SAndroid Build Coastguard Worker           PyExc_RuntimeError,
444*da0073e9SAndroid Build Coastguard Worker           "method '%s' already has a docstring",
445*da0073e9SAndroid Build Coastguard Worker           m->d_method->ml_name);
446*da0073e9SAndroid Build Coastguard Worker     }
447*da0073e9SAndroid Build Coastguard Worker     m->d_method->ml_doc = doc_str;
448*da0073e9SAndroid Build Coastguard Worker   } else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) {
449*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
450*da0073e9SAndroid Build Coastguard Worker     PyGetSetDescrObject* m = (PyGetSetDescrObject*)obj;
451*da0073e9SAndroid Build Coastguard Worker     if (m->d_getset->doc) {
452*da0073e9SAndroid Build Coastguard Worker       return PyErr_Format(
453*da0073e9SAndroid Build Coastguard Worker           PyExc_RuntimeError,
454*da0073e9SAndroid Build Coastguard Worker           "attribute '%s' already has a docstring",
455*da0073e9SAndroid Build Coastguard Worker           m->d_getset->name);
456*da0073e9SAndroid Build Coastguard Worker     }
457*da0073e9SAndroid Build Coastguard Worker     m->d_getset->doc = doc_str;
458*da0073e9SAndroid Build Coastguard Worker   } else if (Py_TYPE(obj) == &PyType_Type) {
459*da0073e9SAndroid Build Coastguard Worker     PyTypeObject* t = (PyTypeObject*)obj;
460*da0073e9SAndroid Build Coastguard Worker     if (t->tp_doc) {
461*da0073e9SAndroid Build Coastguard Worker       return PyErr_Format(
462*da0073e9SAndroid Build Coastguard Worker           PyExc_RuntimeError, "Type '%s' already has a docstring", t->tp_name);
463*da0073e9SAndroid Build Coastguard Worker     }
464*da0073e9SAndroid Build Coastguard Worker     t->tp_doc = doc_str;
465*da0073e9SAndroid Build Coastguard Worker   } else {
466*da0073e9SAndroid Build Coastguard Worker     return PyErr_Format(
467*da0073e9SAndroid Build Coastguard Worker         PyExc_TypeError,
468*da0073e9SAndroid Build Coastguard Worker         "don't know how to add docstring to type '%s'",
469*da0073e9SAndroid Build Coastguard Worker         Py_TYPE(obj)->tp_name);
470*da0073e9SAndroid Build Coastguard Worker   }
471*da0073e9SAndroid Build Coastguard Worker 
472*da0073e9SAndroid Build Coastguard Worker   Py_INCREF(obj);
473*da0073e9SAndroid Build Coastguard Worker   return obj;
474*da0073e9SAndroid Build Coastguard Worker }
475*da0073e9SAndroid Build Coastguard Worker 
THPModule_inferSize(PyObject * _unused,PyObject * args)476*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_inferSize(PyObject* _unused, PyObject* args) {
477*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
478*da0073e9SAndroid Build Coastguard Worker   Py_ssize_t num_args = args ? (Py_ssize_t)PyTuple_Size(args) : 0;
479*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(num_args == 2, "expected exactly 2 arguments");
480*da0073e9SAndroid Build Coastguard Worker   PyObject* arg1 = PyTuple_GET_ITEM(args, 0);
481*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(THPSize_Check(arg1), "expected a torch.Size as argument 1");
482*da0073e9SAndroid Build Coastguard Worker   PyObject* arg2 = PyTuple_GET_ITEM(args, 1);
483*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(THPSize_Check(arg2), "expected a torch.Size as argument 2");
484*da0073e9SAndroid Build Coastguard Worker 
485*da0073e9SAndroid Build Coastguard Worker   auto size1 = THPUtils_unpackLongs(arg1);
486*da0073e9SAndroid Build Coastguard Worker   auto size2 = THPUtils_unpackLongs(arg2);
487*da0073e9SAndroid Build Coastguard Worker   auto sizes = at::infer_size(size1, size2);
488*da0073e9SAndroid Build Coastguard Worker   return THPSize_NewFromSizes(static_cast<int64_t>(sizes.size()), sizes.data());
489*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
490*da0073e9SAndroid Build Coastguard Worker }
491*da0073e9SAndroid Build Coastguard Worker 
THPModule_setBackcompatBroadcastWarn(PyObject * module,PyObject * arg)492*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_setBackcompatBroadcastWarn(
493*da0073e9SAndroid Build Coastguard Worker     PyObject* module,
494*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
495*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
496*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
497*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
498*da0073e9SAndroid Build Coastguard Worker       "set_backcompat_broadcast_warn expects a bool, "
499*da0073e9SAndroid Build Coastguard Worker       "but got ",
500*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
501*da0073e9SAndroid Build Coastguard Worker   setBackCompatBroadcastWarn(arg == Py_True);
502*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
503*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
504*da0073e9SAndroid Build Coastguard Worker }
505*da0073e9SAndroid Build Coastguard Worker 
THPModule_getBackcompatBroadcastWarn(PyObject * module,PyObject * noargs)506*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_getBackcompatBroadcastWarn(
507*da0073e9SAndroid Build Coastguard Worker     PyObject* module,
508*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
509*da0073e9SAndroid Build Coastguard Worker   if (getBackCompatBroadcastWarn())
510*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
511*da0073e9SAndroid Build Coastguard Worker   else
512*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
513*da0073e9SAndroid Build Coastguard Worker }
514*da0073e9SAndroid Build Coastguard Worker 
THPModule_setBackcompatKeepdimWarn(PyObject * module,PyObject * arg)515*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_setBackcompatKeepdimWarn(
516*da0073e9SAndroid Build Coastguard Worker     PyObject* module,
517*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
518*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
519*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
520*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
521*da0073e9SAndroid Build Coastguard Worker       "set_backcompat_keepdim_warn expects a bool, "
522*da0073e9SAndroid Build Coastguard Worker       "but got ",
523*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
524*da0073e9SAndroid Build Coastguard Worker   setBackCompatKeepdimWarn(arg == Py_True);
525*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
526*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
527*da0073e9SAndroid Build Coastguard Worker }
528*da0073e9SAndroid Build Coastguard Worker 
THPModule_getBackcompatKeepdimWarn(PyObject * module,PyObject * noargs)529*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_getBackcompatKeepdimWarn(
530*da0073e9SAndroid Build Coastguard Worker     PyObject* module,
531*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
532*da0073e9SAndroid Build Coastguard Worker   if (getBackCompatKeepdimWarn())
533*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
534*da0073e9SAndroid Build Coastguard Worker   else
535*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
536*da0073e9SAndroid Build Coastguard Worker }
537*da0073e9SAndroid Build Coastguard Worker 
THPModule_hasDistributed(PyObject * _unused,PyObject * noargs)538*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) {
539*da0073e9SAndroid Build Coastguard Worker #ifdef USE_DISTRIBUTED
540*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_TRUE;
541*da0073e9SAndroid Build Coastguard Worker #else
542*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
543*da0073e9SAndroid Build Coastguard Worker #endif
544*da0073e9SAndroid Build Coastguard Worker }
545*da0073e9SAndroid Build Coastguard Worker 
THPModule_showConfig(PyObject * module,PyObject * noargs)546*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) {
547*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
548*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packString(at::show_config());
549*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
550*da0073e9SAndroid Build Coastguard Worker }
551*da0073e9SAndroid Build Coastguard Worker 
THPModule_cxxFlags(PyObject * module,PyObject * noargs)552*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_cxxFlags(PyObject* module, PyObject* noargs) {
553*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
554*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packString(at::get_cxx_flags());
555*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
556*da0073e9SAndroid Build Coastguard Worker }
557*da0073e9SAndroid Build Coastguard Worker 
THPModule_parallelInfo(PyObject * module,PyObject * noargs)558*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_parallelInfo(PyObject* module, PyObject* noargs) {
559*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
560*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packString(at::get_parallel_info());
561*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
562*da0073e9SAndroid Build Coastguard Worker }
563*da0073e9SAndroid Build Coastguard Worker 
THPModule_getCpuCapability(PyObject * module,PyObject * noargs)564*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_getCpuCapability(
565*da0073e9SAndroid Build Coastguard Worker     PyObject* module,
566*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
567*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
568*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packString(at::get_cpu_capability());
569*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
570*da0073e9SAndroid Build Coastguard Worker }
571*da0073e9SAndroid Build Coastguard Worker 
DLPack_Capsule_Destructor(PyObject * data)572*da0073e9SAndroid Build Coastguard Worker void DLPack_Capsule_Destructor(PyObject* data) {
573*da0073e9SAndroid Build Coastguard Worker   if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor"))) {
574*da0073e9SAndroid Build Coastguard Worker     // early out, see DLPack spec: if a consuming library sets the capsule
575*da0073e9SAndroid Build Coastguard Worker     // name to something else, they own it and we don't need to do anything
576*da0073e9SAndroid Build Coastguard Worker     return;
577*da0073e9SAndroid Build Coastguard Worker   }
578*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
579*da0073e9SAndroid Build Coastguard Worker   // Causes overheads for validity checks again, but this case is rare
580*da0073e9SAndroid Build Coastguard Worker   // since consuming libraries should rename the capsule according to spec.
581*da0073e9SAndroid Build Coastguard Worker   // Note that this cannot set a python error (we checked validity above),
582*da0073e9SAndroid Build Coastguard Worker   // so we don't need to handle python error state here.
583*da0073e9SAndroid Build Coastguard Worker   DLManagedTensor* dlMTensor =
584*da0073e9SAndroid Build Coastguard Worker       (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
585*da0073e9SAndroid Build Coastguard Worker   // the dlMTensor has not been consumed, call deleter ourselves.
586*da0073e9SAndroid Build Coastguard Worker   // DLPack spec mentions that deleter may be NULL, but deleter from
587*da0073e9SAndroid Build Coastguard Worker   // `at::toDLPack` is never NULL, so no need for an additional check here.
588*da0073e9SAndroid Build Coastguard Worker   dlMTensor->deleter(dlMTensor);
589*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS_RET()
590*da0073e9SAndroid Build Coastguard Worker }
591*da0073e9SAndroid Build Coastguard Worker 
THPModule_toDLPack(PyObject * _unused,PyObject * data)592*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) {
593*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
594*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor");
595*da0073e9SAndroid Build Coastguard Worker   DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data));
596*da0073e9SAndroid Build Coastguard Worker   return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
597*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
598*da0073e9SAndroid Build Coastguard Worker }
599*da0073e9SAndroid Build Coastguard Worker 
THPModule_fromDLPack(PyObject * _unused,PyObject * data)600*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) {
601*da0073e9SAndroid Build Coastguard Worker   using namespace torch::autograd;
602*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
603*da0073e9SAndroid Build Coastguard Worker   auto tensor = torch::utils::tensor_fromDLPack(data);
604*da0073e9SAndroid Build Coastguard Worker   return THPVariable_Wrap(tensor);
605*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
606*da0073e9SAndroid Build Coastguard Worker }
607*da0073e9SAndroid Build Coastguard Worker 
THModule_getCppBacktrace(PyObject * _unused,PyObject * args)608*da0073e9SAndroid Build Coastguard Worker PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) {
609*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
610*da0073e9SAndroid Build Coastguard Worker   size_t frames_to_skip = 0;
611*da0073e9SAndroid Build Coastguard Worker   size_t maximum_number_of_frames = 0;
612*da0073e9SAndroid Build Coastguard Worker   if (!PyArg_ParseTuple(
613*da0073e9SAndroid Build Coastguard Worker           args, "LL", &frames_to_skip, &maximum_number_of_frames)) {
614*da0073e9SAndroid Build Coastguard Worker     return nullptr;
615*da0073e9SAndroid Build Coastguard Worker   }
616*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packString(
617*da0073e9SAndroid Build Coastguard Worker       c10::get_backtrace(frames_to_skip, maximum_number_of_frames, true));
618*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
619*da0073e9SAndroid Build Coastguard Worker }
620*da0073e9SAndroid Build Coastguard Worker 
THModule_rename_privateuse1_backend(PyObject * _unused,PyObject * arg)621*da0073e9SAndroid Build Coastguard Worker static PyObject* THModule_rename_privateuse1_backend(
622*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
623*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
624*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
625*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
626*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkString(arg),
627*da0073e9SAndroid Build Coastguard Worker       "_rename_privateuse1_backend expects a str, but got ",
628*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
629*da0073e9SAndroid Build Coastguard Worker   const std::string backend_name = THPUtils_unpackString(arg);
630*da0073e9SAndroid Build Coastguard Worker   c10::register_privateuse1_backend(backend_name);
631*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
632*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
633*da0073e9SAndroid Build Coastguard Worker }
634*da0073e9SAndroid Build Coastguard Worker 
THModule_get_privateuse1_backend_name(PyObject * _unused,PyObject * arg)635*da0073e9SAndroid Build Coastguard Worker static PyObject* THModule_get_privateuse1_backend_name(
636*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
637*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
638*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
639*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packString(c10::get_privateuse1_backend());
640*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
641*da0073e9SAndroid Build Coastguard Worker }
642*da0073e9SAndroid Build Coastguard Worker 
THPModule_setAllowTF32CuDNN(PyObject * _unused,PyObject * arg)643*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
644*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
645*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
646*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
647*da0073e9SAndroid Build Coastguard Worker       "set_allow_tf32_cublas expects a bool, "
648*da0073e9SAndroid Build Coastguard Worker       "but got ",
649*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
650*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setAllowTF32CuDNN(arg == Py_True);
651*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
652*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
653*da0073e9SAndroid Build Coastguard Worker }
654*da0073e9SAndroid Build Coastguard Worker 
THPModule_allowTF32CuDNN(PyObject * _unused,PyObject * noargs)655*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) {
656*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().allowTF32CuDNN())
657*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
658*da0073e9SAndroid Build Coastguard Worker   else
659*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
660*da0073e9SAndroid Build Coastguard Worker }
661*da0073e9SAndroid Build Coastguard Worker 
THPModule_setFloat32MatmulPrecision(PyObject * _unused,PyObject * arg)662*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setFloat32MatmulPrecision(
663*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
664*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
665*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
666*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
667*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkString(arg),
668*da0073e9SAndroid Build Coastguard Worker       "set_float32_matmul_precision expects a str, "
669*da0073e9SAndroid Build Coastguard Worker       "but got ",
670*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
671*da0073e9SAndroid Build Coastguard Worker   std::string s = THPUtils_unpackString(arg);
672*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setFloat32MatmulPrecision(s);
673*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
674*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
675*da0073e9SAndroid Build Coastguard Worker }
676*da0073e9SAndroid Build Coastguard Worker 
THPModule_float32MatmulPrecision(PyObject * _unused,PyObject * noargs)677*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_float32MatmulPrecision(
678*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
679*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
680*da0073e9SAndroid Build Coastguard Worker   std::string s = "highest";
681*da0073e9SAndroid Build Coastguard Worker   auto p = at::globalContext().float32MatmulPrecision();
682*da0073e9SAndroid Build Coastguard Worker   if (p == at::Float32MatmulPrecision::HIGH) {
683*da0073e9SAndroid Build Coastguard Worker     s = "high";
684*da0073e9SAndroid Build Coastguard Worker   } else if (p == at::Float32MatmulPrecision::MEDIUM) {
685*da0073e9SAndroid Build Coastguard Worker     s = "medium";
686*da0073e9SAndroid Build Coastguard Worker   }
687*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packString(s);
688*da0073e9SAndroid Build Coastguard Worker }
THPModule_setSDPUseFlash(PyObject * _unused,PyObject * arg)689*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setSDPUseFlash(PyObject* _unused, PyObject* arg) {
690*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
691*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
692*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
693*da0073e9SAndroid Build Coastguard Worker       "set_sdp_use_math expects a bool, "
694*da0073e9SAndroid Build Coastguard Worker       "but got ",
695*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
696*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setSDPUseFlash(arg == Py_True);
697*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
698*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
699*da0073e9SAndroid Build Coastguard Worker }
THPModule_userEnabledFlashSDP(PyObject * _unused,PyObject * noargs)700*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) {
701*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().userEnabledFlashSDP())
702*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
703*da0073e9SAndroid Build Coastguard Worker   else
704*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
705*da0073e9SAndroid Build Coastguard Worker }
THPModule_setSDPUseMemEfficient(PyObject * _unused,PyObject * arg)706*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) {
707*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
708*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
709*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
710*da0073e9SAndroid Build Coastguard Worker       "set_sdp_use_math expects a bool, "
711*da0073e9SAndroid Build Coastguard Worker       "but got ",
712*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
713*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setSDPUseMemEfficient(arg == Py_True);
714*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
715*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
716*da0073e9SAndroid Build Coastguard Worker }
userEnabledMemEfficientSDP(PyObject * _unused,PyObject * noargs)717*da0073e9SAndroid Build Coastguard Worker PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) {
718*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().userEnabledMemEfficientSDP())
719*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
720*da0073e9SAndroid Build Coastguard Worker   else
721*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
722*da0073e9SAndroid Build Coastguard Worker }
THPModule_setSDPUseMath(PyObject * _unused,PyObject * arg)723*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) {
724*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
725*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
726*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
727*da0073e9SAndroid Build Coastguard Worker       "set_sdp_use_math expects a bool, "
728*da0073e9SAndroid Build Coastguard Worker       "but got ",
729*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
730*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setSDPUseMath(arg == Py_True);
731*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
732*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
733*da0073e9SAndroid Build Coastguard Worker }
THPModule_userEnabledMathSDP(PyObject * _unused,PyObject * noargs)734*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) {
735*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().userEnabledMathSDP())
736*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
737*da0073e9SAndroid Build Coastguard Worker   else
738*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
739*da0073e9SAndroid Build Coastguard Worker }
THPModule_setAllowFP16BF16ReductionMathSDP(PyObject * _unused,PyObject * arg)740*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setAllowFP16BF16ReductionMathSDP(
741*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
742*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
743*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
744*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
745*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
746*da0073e9SAndroid Build Coastguard Worker       "set_sdp_use_math expects a bool, "
747*da0073e9SAndroid Build Coastguard Worker       "but got ",
748*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
749*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setAllowFP16BF16ReductionMathSDP(arg == Py_True);
750*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
751*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
752*da0073e9SAndroid Build Coastguard Worker }
THPModule_allowFP16BF16ReductionMathSDP(PyObject * _unused,PyObject * noargs)753*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_allowFP16BF16ReductionMathSDP(
754*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
755*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
756*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().allowFP16BF16ReductionMathSDP())
757*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
758*da0073e9SAndroid Build Coastguard Worker   else
759*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
760*da0073e9SAndroid Build Coastguard Worker }
THPModule_setSDPUseOverrideable(PyObject * _unused,PyObject * arg)761*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setSDPUseOverrideable(PyObject* _unused, PyObject* arg) {
762*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
763*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
764*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
765*da0073e9SAndroid Build Coastguard Worker       "set_sdp_use_overrideable expects a bool, "
766*da0073e9SAndroid Build Coastguard Worker       "but got ",
767*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
768*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setSDPUseOverrideable(arg == Py_True);
769*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
770*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
771*da0073e9SAndroid Build Coastguard Worker }
THPModule_userEnabledOverrideableSDP(PyObject * _unused,PyObject * noargs)772*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_userEnabledOverrideableSDP(
773*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
774*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
775*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().userEnabledOverrideableSDP())
776*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
777*da0073e9SAndroid Build Coastguard Worker   else
778*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
779*da0073e9SAndroid Build Coastguard Worker }
THPModule_setSDPUseCuDNN(PyObject * _unused,PyObject * arg)780*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setSDPUseCuDNN(PyObject* _unused, PyObject* arg) {
781*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
782*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
783*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
784*da0073e9SAndroid Build Coastguard Worker       "set_sdp_use_cudnn expects a bool, "
785*da0073e9SAndroid Build Coastguard Worker       "but got %s",
786*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
787*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setSDPUseCuDNN(arg == Py_True);
788*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
789*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
790*da0073e9SAndroid Build Coastguard Worker }
THPModule_userEnabledCuDNNSDP(PyObject * _unused,PyObject * noargs)791*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_userEnabledCuDNNSDP(PyObject* _unused, PyObject* noargs) {
792*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().userEnabledCuDNNSDP())
793*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
794*da0073e9SAndroid Build Coastguard Worker   else
795*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
796*da0073e9SAndroid Build Coastguard Worker }
797*da0073e9SAndroid Build Coastguard Worker 
THPModule_setUserEnabledCuDNN(PyObject * _unused,PyObject * arg)798*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setUserEnabledCuDNN(PyObject* _unused, PyObject* arg) {
799*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
800*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
801*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
802*da0073e9SAndroid Build Coastguard Worker       "set_enabled_cudnn expects a bool, "
803*da0073e9SAndroid Build Coastguard Worker       "but got ",
804*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
805*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setUserEnabledCuDNN(arg == Py_True);
806*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
807*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
808*da0073e9SAndroid Build Coastguard Worker }
809*da0073e9SAndroid Build Coastguard Worker 
THPModule_userEnabledCuDNN(PyObject * _unused,PyObject * noargs)810*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_userEnabledCuDNN(PyObject* _unused, PyObject* noargs) {
811*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().userEnabledCuDNN())
812*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
813*da0073e9SAndroid Build Coastguard Worker   else
814*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
815*da0073e9SAndroid Build Coastguard Worker }
816*da0073e9SAndroid Build Coastguard Worker 
THPModule_setUserEnabledMkldnn(PyObject * _unused,PyObject * arg)817*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setUserEnabledMkldnn(PyObject* _unused, PyObject* arg) {
818*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
819*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
820*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
821*da0073e9SAndroid Build Coastguard Worker       "set_enabled_mkldnn expects a bool, "
822*da0073e9SAndroid Build Coastguard Worker       "but got ",
823*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
824*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setUserEnabledMkldnn(arg == Py_True);
825*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
826*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
827*da0073e9SAndroid Build Coastguard Worker }
828*da0073e9SAndroid Build Coastguard Worker 
THPModule_userEnabledMkldnn(PyObject * _unused,PyObject * noargs)829*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_userEnabledMkldnn(PyObject* _unused, PyObject* noargs) {
830*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().userEnabledMkldnn())
831*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
832*da0073e9SAndroid Build Coastguard Worker   else
833*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
834*da0073e9SAndroid Build Coastguard Worker }
835*da0073e9SAndroid Build Coastguard Worker 
THPModule_setDeterministicCuDNN(PyObject * _unused,PyObject * arg)836*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setDeterministicCuDNN(PyObject* _unused, PyObject* arg) {
837*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
838*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
839*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
840*da0073e9SAndroid Build Coastguard Worker       "set_deterministic_cudnn expects a bool, "
841*da0073e9SAndroid Build Coastguard Worker       "but got ",
842*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
843*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setDeterministicCuDNN(arg == Py_True);
844*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
845*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
846*da0073e9SAndroid Build Coastguard Worker }
847*da0073e9SAndroid Build Coastguard Worker 
THPModule_deterministicCuDNN(PyObject * _unused,PyObject * noargs)848*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_deterministicCuDNN(PyObject* _unused, PyObject* noargs) {
849*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().deterministicCuDNN())
850*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
851*da0073e9SAndroid Build Coastguard Worker   else
852*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
853*da0073e9SAndroid Build Coastguard Worker }
854*da0073e9SAndroid Build Coastguard Worker 
THPModule_setDeterministicMkldnn(PyObject * _unused,PyObject * arg)855*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setDeterministicMkldnn(PyObject* _unused, PyObject* arg) {
856*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
857*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
858*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
859*da0073e9SAndroid Build Coastguard Worker       "set_deterministic_mkldnn expects a bool, "
860*da0073e9SAndroid Build Coastguard Worker       "but got ",
861*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
862*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setDeterministicMkldnn(arg == Py_True);
863*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
864*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
865*da0073e9SAndroid Build Coastguard Worker }
866*da0073e9SAndroid Build Coastguard Worker 
THPModule_deterministicMkldnn(PyObject * _unused,PyObject * noargs)867*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_deterministicMkldnn(PyObject* _unused, PyObject* noargs) {
868*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().deterministicMkldnn())
869*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
870*da0073e9SAndroid Build Coastguard Worker   else
871*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
872*da0073e9SAndroid Build Coastguard Worker }
873*da0073e9SAndroid Build Coastguard Worker 
THPModule_setDeterministicAlgorithms(PyObject * _unused,PyObject * args,PyObject * kwargs)874*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setDeterministicAlgorithms(
875*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
876*da0073e9SAndroid Build Coastguard Worker     PyObject* args,
877*da0073e9SAndroid Build Coastguard Worker     PyObject* kwargs) {
878*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
879*da0073e9SAndroid Build Coastguard Worker   static torch::PythonArgParser parser(
880*da0073e9SAndroid Build Coastguard Worker       {"_set_deterministic_algorithms(bool mode, *, bool warn_only=False)"});
881*da0073e9SAndroid Build Coastguard Worker   torch::ParsedArgs<2> parsed_args{};
882*da0073e9SAndroid Build Coastguard Worker   auto r = parser.parse(args, kwargs, parsed_args);
883*da0073e9SAndroid Build Coastguard Worker   bool mode = r.toBool(0);
884*da0073e9SAndroid Build Coastguard Worker   bool warn_only = r.toBool(1);
885*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setDeterministicAlgorithms(mode, warn_only);
886*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
887*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
888*da0073e9SAndroid Build Coastguard Worker }
889*da0073e9SAndroid Build Coastguard Worker 
THPModule_deterministicAlgorithms(PyObject * _unused,PyObject * noargs)890*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_deterministicAlgorithms(
891*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
892*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
893*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().deterministicAlgorithms()) {
894*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
895*da0073e9SAndroid Build Coastguard Worker   }
896*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
897*da0073e9SAndroid Build Coastguard Worker }
898*da0073e9SAndroid Build Coastguard Worker 
THPModule_deterministicAlgorithmsWarnOnly(PyObject * _unused,PyObject * noargs)899*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_deterministicAlgorithmsWarnOnly(
900*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
901*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
902*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().deterministicAlgorithmsWarnOnly()) {
903*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
904*da0073e9SAndroid Build Coastguard Worker   }
905*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
906*da0073e9SAndroid Build Coastguard Worker }
907*da0073e9SAndroid Build Coastguard Worker 
THPModule_setDeterministicFillUninitializedMemory(PyObject * _unused,PyObject * arg)908*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setDeterministicFillUninitializedMemory(
909*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
910*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
911*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
912*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
913*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg), "expected a bool, but got ", THPUtils_typename(arg));
914*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setDeterministicFillUninitializedMemory(arg == Py_True);
915*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
916*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
917*da0073e9SAndroid Build Coastguard Worker }
918*da0073e9SAndroid Build Coastguard Worker 
THPModule_deterministicFillUninitializedMemory(PyObject * _unused,PyObject * noargs)919*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_deterministicFillUninitializedMemory(
920*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
921*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
922*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().deterministicFillUninitializedMemory())
923*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
924*da0073e9SAndroid Build Coastguard Worker   else
925*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
926*da0073e9SAndroid Build Coastguard Worker }
927*da0073e9SAndroid Build Coastguard Worker 
THPModule_setUserEnabledNNPACK(PyObject * _unused,PyObject * arg)928*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setUserEnabledNNPACK(PyObject* _unused, PyObject* arg) {
929*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
930*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
931*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
932*da0073e9SAndroid Build Coastguard Worker       "set_enabled_NNPACK expects a bool, "
933*da0073e9SAndroid Build Coastguard Worker       "but got ",
934*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
935*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setUserEnabledNNPACK(arg == Py_True);
936*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
937*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
938*da0073e9SAndroid Build Coastguard Worker }
939*da0073e9SAndroid Build Coastguard Worker 
THPModule_userEnabledNNPACK(PyObject * _unused,PyObject * noargs)940*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_userEnabledNNPACK(PyObject* _unused, PyObject* noargs) {
941*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().userEnabledNNPACK())
942*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
943*da0073e9SAndroid Build Coastguard Worker   else
944*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
945*da0073e9SAndroid Build Coastguard Worker }
946*da0073e9SAndroid Build Coastguard Worker 
THPModule_setWarnAlways(PyObject * _unused,PyObject * arg)947*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) {
948*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
949*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
950*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
951*da0073e9SAndroid Build Coastguard Worker       "setWarnOnlyOnce expects a bool, "
952*da0073e9SAndroid Build Coastguard Worker       "but got ",
953*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
954*da0073e9SAndroid Build Coastguard Worker   c10::WarningUtils::set_warnAlways(arg == Py_True);
955*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
956*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
957*da0073e9SAndroid Build Coastguard Worker }
958*da0073e9SAndroid Build Coastguard Worker 
THPModule_warnAlways(PyObject * _unused,PyObject * noargs)959*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_warnAlways(PyObject* _unused, PyObject* noargs) {
960*da0073e9SAndroid Build Coastguard Worker   if (c10::WarningUtils::get_warnAlways()) {
961*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
962*da0073e9SAndroid Build Coastguard Worker   }
963*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
964*da0073e9SAndroid Build Coastguard Worker }
965*da0073e9SAndroid Build Coastguard Worker 
966*da0073e9SAndroid Build Coastguard Worker // Used only for testing C++ to Python warning translations.
THPModule_warn(PyObject * _unused,PyObject * noargs)967*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_warn(PyObject* _unused, PyObject* noargs) {
968*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
969*da0073e9SAndroid Build Coastguard Worker   TORCH_WARN("Test message for TORCH_WARN");
970*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
971*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
972*da0073e9SAndroid Build Coastguard Worker }
973*da0073e9SAndroid Build Coastguard Worker 
974*da0073e9SAndroid Build Coastguard Worker // Used only for testing C++ to Python warning translations.
THPModule_warnDeprecation(PyObject * _unused,PyObject * noargs)975*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_warnDeprecation(PyObject* _unused, PyObject* noargs) {
976*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
977*da0073e9SAndroid Build Coastguard Worker   TORCH_WARN_DEPRECATION("Test message for TORCH_WARN_DEPRECATION");
978*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
979*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
980*da0073e9SAndroid Build Coastguard Worker }
981*da0073e9SAndroid Build Coastguard Worker 
THPModule_setBenchmarkCuDNN(PyObject * _unused,PyObject * arg)982*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setBenchmarkCuDNN(PyObject* _unused, PyObject* arg) {
983*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
984*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
985*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
986*da0073e9SAndroid Build Coastguard Worker       "set_benchmark_cudnn expects a bool, "
987*da0073e9SAndroid Build Coastguard Worker       "but got ",
988*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
989*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setBenchmarkCuDNN(arg == Py_True);
990*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
991*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
992*da0073e9SAndroid Build Coastguard Worker }
993*da0073e9SAndroid Build Coastguard Worker 
THPModule_benchmarkCuDNN(PyObject * _unused,PyObject * noargs)994*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) {
995*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().benchmarkCuDNN()) {
996*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
997*da0073e9SAndroid Build Coastguard Worker   }
998*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
999*da0073e9SAndroid Build Coastguard Worker }
1000*da0073e9SAndroid Build Coastguard Worker 
THPModule_setAllowTF32CuBLAS(PyObject * _unused,PyObject * arg)1001*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setAllowTF32CuBLAS(PyObject* _unused, PyObject* arg) {
1002*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1003*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1004*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
1005*da0073e9SAndroid Build Coastguard Worker       "set_allow_tf32_cublas expects a bool, "
1006*da0073e9SAndroid Build Coastguard Worker       "but got ",
1007*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1008*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setAllowTF32CuBLAS(arg == Py_True);
1009*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1010*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1011*da0073e9SAndroid Build Coastguard Worker }
1012*da0073e9SAndroid Build Coastguard Worker 
THPModule_allowTF32CuBLAS(PyObject * _unused,PyObject * noargs)1013*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_allowTF32CuBLAS(PyObject* _unused, PyObject* noargs) {
1014*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().allowTF32CuBLAS()) {
1015*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
1016*da0073e9SAndroid Build Coastguard Worker   }
1017*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
1018*da0073e9SAndroid Build Coastguard Worker }
1019*da0073e9SAndroid Build Coastguard Worker 
THPModule_setAllowFP16ReductionCuBLAS(PyObject * _unused,PyObject * arg)1020*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setAllowFP16ReductionCuBLAS(
1021*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1022*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
1023*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1024*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1025*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
1026*da0073e9SAndroid Build Coastguard Worker       "set_allow_fp16_reduction_cublas expects a bool, "
1027*da0073e9SAndroid Build Coastguard Worker       "but got ",
1028*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1029*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True);
1030*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1031*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1032*da0073e9SAndroid Build Coastguard Worker }
1033*da0073e9SAndroid Build Coastguard Worker 
THPModule_allowFP16ReductionCuBLAS(PyObject * _unused,PyObject * noargs)1034*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_allowFP16ReductionCuBLAS(
1035*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1036*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
1037*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().allowFP16ReductionCuBLAS()) {
1038*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
1039*da0073e9SAndroid Build Coastguard Worker   }
1040*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
1041*da0073e9SAndroid Build Coastguard Worker }
1042*da0073e9SAndroid Build Coastguard Worker 
THPModule_setAllowBF16ReductionCuBLAS(PyObject * _unused,PyObject * arg)1043*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setAllowBF16ReductionCuBLAS(
1044*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1045*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
1046*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1047*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1048*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
1049*da0073e9SAndroid Build Coastguard Worker       "set_allow_bf16_reduction_cublas expects a bool, "
1050*da0073e9SAndroid Build Coastguard Worker       "but got ",
1051*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1052*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
1053*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1054*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1055*da0073e9SAndroid Build Coastguard Worker }
1056*da0073e9SAndroid Build Coastguard Worker 
THPModule_allowBF16ReductionCuBLAS(PyObject * _unused,PyObject * noargs)1057*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_allowBF16ReductionCuBLAS(
1058*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1059*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
1060*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().allowBF16ReductionCuBLAS()) {
1061*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
1062*da0073e9SAndroid Build Coastguard Worker   }
1063*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
1064*da0073e9SAndroid Build Coastguard Worker }
1065*da0073e9SAndroid Build Coastguard Worker 
THPModule_setAllowFP16ReductionCPU(PyObject * _unused,PyObject * arg)1066*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setAllowFP16ReductionCPU(PyObject* _unused, PyObject* arg) {
1067*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1068*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1069*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
1070*da0073e9SAndroid Build Coastguard Worker       "set_allow_fp16_reduction_cpu expects a bool, "
1071*da0073e9SAndroid Build Coastguard Worker       "but got ",
1072*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1073*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setAllowFP16ReductionCPU(arg == Py_True);
1074*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1075*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1076*da0073e9SAndroid Build Coastguard Worker }
1077*da0073e9SAndroid Build Coastguard Worker 
THPModule_allowFP16ReductionCPU(PyObject * _unused,PyObject * noargs)1078*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_allowFP16ReductionCPU(PyObject* _unused, PyObject* noargs) {
1079*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().allowFP16ReductionCPU()) {
1080*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
1081*da0073e9SAndroid Build Coastguard Worker   }
1082*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_FALSE;
1083*da0073e9SAndroid Build Coastguard Worker }
1084*da0073e9SAndroid Build Coastguard Worker 
THPModule_setFlushDenormal(PyObject * _unused,PyObject * arg)1085*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) {
1086*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1087*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1088*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
1089*da0073e9SAndroid Build Coastguard Worker       "flush_denormal expects a bool, "
1090*da0073e9SAndroid Build Coastguard Worker       "but got ",
1091*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1092*da0073e9SAndroid Build Coastguard Worker   if (!at::globalContext().setFlushDenormal(arg == Py_True)) {
1093*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
1094*da0073e9SAndroid Build Coastguard Worker   };
1095*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_TRUE;
1096*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1097*da0073e9SAndroid Build Coastguard Worker }
1098*da0073e9SAndroid Build Coastguard Worker 
THPModule_getDefaultDtype(PyObject * _unused,PyObject * arg)1099*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_getDefaultDtype(PyObject* _unused, PyObject* arg) {
1100*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1101*da0073e9SAndroid Build Coastguard Worker   auto scalar_type = torch::tensors::get_default_scalar_type();
1102*da0073e9SAndroid Build Coastguard Worker   return Py_NewRef(torch::getTHPDtype(scalar_type));
1103*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1104*da0073e9SAndroid Build Coastguard Worker }
1105*da0073e9SAndroid Build Coastguard Worker 
THPModule_getDefaultDevice(PyObject * _unused,PyObject * arg)1106*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_getDefaultDevice(PyObject* _unused, PyObject* arg) {
1107*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1108*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packString(c10::DeviceTypeName(
1109*da0073e9SAndroid Build Coastguard Worker       dispatchKeyToDeviceType(torch::tensors::get_default_dispatch_key()),
1110*da0073e9SAndroid Build Coastguard Worker       /*lower_case=*/true));
1111*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1112*da0073e9SAndroid Build Coastguard Worker }
1113*da0073e9SAndroid Build Coastguard Worker 
THPModule_setQEngine(PyObject *,PyObject * arg)1114*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setQEngine(PyObject* /* unused */, PyObject* arg) {
1115*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1116*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1117*da0073e9SAndroid Build Coastguard Worker       THPUtils_checkLong(arg),
1118*da0073e9SAndroid Build Coastguard Worker       "set_qengine expects an int, "
1119*da0073e9SAndroid Build Coastguard Worker       "but got ",
1120*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1121*da0073e9SAndroid Build Coastguard Worker   auto qengine = THPUtils_unpackLong(arg);
1122*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setQEngine(static_cast<at::QEngine>(qengine));
1123*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1124*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1125*da0073e9SAndroid Build Coastguard Worker }
1126*da0073e9SAndroid Build Coastguard Worker 
THPModule_qEngine(PyObject * _unused,PyObject * noargs)1127*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_qEngine(PyObject* _unused, PyObject* noargs) {
1128*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt64(
1129*da0073e9SAndroid Build Coastguard Worker       static_cast<int64_t>(at::globalContext().qEngine()));
1130*da0073e9SAndroid Build Coastguard Worker }
1131*da0073e9SAndroid Build Coastguard Worker 
THPModule_supportedQEngines(PyObject * _unused,PyObject * noargs)1132*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_supportedQEngines(PyObject* _unused, PyObject* noargs) {
1133*da0073e9SAndroid Build Coastguard Worker   auto qengines = at::globalContext().supportedQEngines();
1134*da0073e9SAndroid Build Coastguard Worker   auto list =
1135*da0073e9SAndroid Build Coastguard Worker       THPObjectPtr(PyList_New(static_cast<Py_ssize_t>(qengines.size())));
1136*da0073e9SAndroid Build Coastguard Worker   if (!list)
1137*da0073e9SAndroid Build Coastguard Worker     return nullptr;
1138*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(qengines.size())) {
1139*da0073e9SAndroid Build Coastguard Worker     PyObject* i64 = THPUtils_packInt64(static_cast<int64_t>(qengines[i]));
1140*da0073e9SAndroid Build Coastguard Worker     if (!i64)
1141*da0073e9SAndroid Build Coastguard Worker       return nullptr;
1142*da0073e9SAndroid Build Coastguard Worker     PyList_SET_ITEM(list.get(), i, i64);
1143*da0073e9SAndroid Build Coastguard Worker   }
1144*da0073e9SAndroid Build Coastguard Worker   return list.release();
1145*da0073e9SAndroid Build Coastguard Worker }
1146*da0073e9SAndroid Build Coastguard Worker 
THPModule_isEnabledXNNPACK(PyObject * _unused,PyObject * noargs)1147*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
1148*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().isXNNPACKAvailable())
1149*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
1150*da0073e9SAndroid Build Coastguard Worker   else
1151*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
1152*da0073e9SAndroid Build Coastguard Worker }
1153*da0073e9SAndroid Build Coastguard Worker 
THPModule_setCheckSparseTensorInvariants(PyObject * _unused,PyObject * arg)1154*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setCheckSparseTensorInvariants(
1155*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1156*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
1157*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1158*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1159*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
1160*da0073e9SAndroid Build Coastguard Worker       "set_check_sparse_tensor_invariants expects a bool, "
1161*da0073e9SAndroid Build Coastguard Worker       "but got ",
1162*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1163*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setCheckSparseTensorInvariants(arg == Py_True);
1164*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1165*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1166*da0073e9SAndroid Build Coastguard Worker }
1167*da0073e9SAndroid Build Coastguard Worker 
THPModule_checkSparseTensorInvariants(PyObject * _unused,PyObject * noargs)1168*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_checkSparseTensorInvariants(
1169*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1170*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
1171*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().checkSparseTensorInvariants())
1172*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
1173*da0073e9SAndroid Build Coastguard Worker   else
1174*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
1175*da0073e9SAndroid Build Coastguard Worker }
1176*da0073e9SAndroid Build Coastguard Worker 
THPModule_willEngineExecuteNode(PyObject * _unused,PyObject * arg)1177*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
1178*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1179*da0073e9SAndroid Build Coastguard Worker   bool isTHPFunction = THPFunction_Check(arg);
1180*da0073e9SAndroid Build Coastguard Worker   bool isTHPCppFunction = torch::autograd::THPCppFunction_Check(arg);
1181*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1182*da0073e9SAndroid Build Coastguard Worker       isTHPFunction || isTHPCppFunction,
1183*da0073e9SAndroid Build Coastguard Worker       "_will_engine_execute_node expects an grad_fn, "
1184*da0073e9SAndroid Build Coastguard Worker       "but got ",
1185*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1186*da0073e9SAndroid Build Coastguard Worker   const auto exec_info = torch::autograd::get_current_graph_task_exec_info();
1187*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1188*da0073e9SAndroid Build Coastguard Worker       exec_info,
1189*da0073e9SAndroid Build Coastguard Worker       "_get_should_execute_nodes should only be called during the backward pass");
1190*da0073e9SAndroid Build Coastguard Worker   torch::autograd::Node* node = nullptr;
1191*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<torch::autograd::Node> node_sp;
1192*da0073e9SAndroid Build Coastguard Worker   if (isTHPFunction) {
1193*da0073e9SAndroid Build Coastguard Worker     node_sp = ((THPFunction*)arg)->cdata.lock();
1194*da0073e9SAndroid Build Coastguard Worker     node = node_sp.get();
1195*da0073e9SAndroid Build Coastguard Worker   } else {
1196*da0073e9SAndroid Build Coastguard Worker     node = ((torch::autograd::THPCppFunction*)arg)->cdata.get();
1197*da0073e9SAndroid Build Coastguard Worker   }
1198*da0073e9SAndroid Build Coastguard Worker   const auto nodes_in_graph =
1199*da0073e9SAndroid Build Coastguard Worker       torch::autograd::get_current_graph_task_nodes_in_graph();
1200*da0073e9SAndroid Build Coastguard Worker   bool ret = nodes_in_graph->find(node) != nodes_in_graph->end();
1201*da0073e9SAndroid Build Coastguard Worker   if (ret && !exec_info->empty()) {
1202*da0073e9SAndroid Build Coastguard Worker     auto it = exec_info->find(node);
1203*da0073e9SAndroid Build Coastguard Worker     if (it == exec_info->end() || !it->second.should_execute()) {
1204*da0073e9SAndroid Build Coastguard Worker       ret = false;
1205*da0073e9SAndroid Build Coastguard Worker     } else {
1206*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(
1207*da0073e9SAndroid Build Coastguard Worker           !(node->topological_nr() == 0 && it->second.captures_),
1208*da0073e9SAndroid Build Coastguard Worker           "A leaf node was passed to _will_engine_execute_node but we are "
1209*da0073e9SAndroid Build Coastguard Worker           "currently running autograd.grad(). This is currently not supported.");
1210*da0073e9SAndroid Build Coastguard Worker     }
1211*da0073e9SAndroid Build Coastguard Worker   }
1212*da0073e9SAndroid Build Coastguard Worker   if (ret) {
1213*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
1214*da0073e9SAndroid Build Coastguard Worker   } else {
1215*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
1216*da0073e9SAndroid Build Coastguard Worker   }
1217*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1218*da0073e9SAndroid Build Coastguard Worker }
1219*da0073e9SAndroid Build Coastguard Worker 
THPModule_getCurrentGraphTaskExecutionOrder(PyObject * _unused,PyObject * noargs)1220*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_getCurrentGraphTaskExecutionOrder(
1221*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1222*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
1223*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1224*da0073e9SAndroid Build Coastguard Worker   std::vector<torch::autograd::Node*> nodes =
1225*da0073e9SAndroid Build Coastguard Worker       torch::autograd::get_current_graph_task_execution_order();
1226*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1227*da0073e9SAndroid Build Coastguard Worker       !nodes.empty(),
1228*da0073e9SAndroid Build Coastguard Worker       "_current_graph_task_execution_order should only be called during the backward pass");
1229*da0073e9SAndroid Build Coastguard Worker   auto list = THPObjectPtr(PyList_New(static_cast<Py_ssize_t>(nodes.size())));
1230*da0073e9SAndroid Build Coastguard Worker   if (!list)
1231*da0073e9SAndroid Build Coastguard Worker     return nullptr;
1232*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(nodes.size())) {
1233*da0073e9SAndroid Build Coastguard Worker     // This node is guaranteed to be alive since the backward is still running
1234*da0073e9SAndroid Build Coastguard Worker     PyObject* pyobj_node =
1235*da0073e9SAndroid Build Coastguard Worker         torch::autograd::functionToPyObject(nodes[i]->getptr());
1236*da0073e9SAndroid Build Coastguard Worker     PyList_SET_ITEM(list.get(), i, pyobj_node);
1237*da0073e9SAndroid Build Coastguard Worker   }
1238*da0073e9SAndroid Build Coastguard Worker   return list.release();
1239*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1240*da0073e9SAndroid Build Coastguard Worker }
1241*da0073e9SAndroid Build Coastguard Worker 
THPModule_getCurrentGraphTaskId(PyObject * _unused,PyObject * noargs)1242*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) {
1243*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1244*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt64(torch::autograd::get_current_graph_task_id());
1245*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1246*da0073e9SAndroid Build Coastguard Worker }
1247*da0073e9SAndroid Build Coastguard Worker 
THPModule_getCurrentNode(PyObject * _unused,PyObject * noargs)1248*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_getCurrentNode(PyObject* _unused, PyObject* noargs) {
1249*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1250*da0073e9SAndroid Build Coastguard Worker   return torch::autograd::functionToPyObject(
1251*da0073e9SAndroid Build Coastguard Worker       torch::autograd::get_current_node());
1252*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1253*da0073e9SAndroid Build Coastguard Worker }
1254*da0073e9SAndroid Build Coastguard Worker 
THPModule_setDefaultMobileCPUAllocator(PyObject * _unused,PyObject * noargs)1255*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_setDefaultMobileCPUAllocator(
1256*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1257*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
1258*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1259*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setDefaultMobileCPUAllocator();
1260*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1261*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1262*da0073e9SAndroid Build Coastguard Worker }
1263*da0073e9SAndroid Build Coastguard Worker 
THPModule_unsetDefaultMobileCPUAllocator(PyObject * _unused,PyObject * noargs)1264*da0073e9SAndroid Build Coastguard Worker PyObject* THPModule_unsetDefaultMobileCPUAllocator(
1265*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1266*da0073e9SAndroid Build Coastguard Worker     PyObject* noargs) {
1267*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1268*da0073e9SAndroid Build Coastguard Worker   at::globalContext().unsetDefaultMobileCPUAllocator();
1269*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1270*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1271*da0073e9SAndroid Build Coastguard Worker }
1272*da0073e9SAndroid Build Coastguard Worker 
THPModule_vmapmode_increment_nesting(PyObject * _unused,PyObject * arg)1273*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_vmapmode_increment_nesting(
1274*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1275*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
1276*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1277*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt64(at::impl::VmapMode::increment_nesting());
1278*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1279*da0073e9SAndroid Build Coastguard Worker }
1280*da0073e9SAndroid Build Coastguard Worker 
THPModule_vmapmode_decrement_nesting(PyObject * _unused,PyObject * arg)1281*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_vmapmode_decrement_nesting(
1282*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1283*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
1284*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1285*da0073e9SAndroid Build Coastguard Worker   return THPUtils_packInt64(at::impl::VmapMode::decrement_nesting());
1286*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1287*da0073e9SAndroid Build Coastguard Worker }
1288*da0073e9SAndroid Build Coastguard Worker 
THPModule_set_display_vmap_fallback_warnings_mode(PyObject * _unused,PyObject * arg)1289*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_set_display_vmap_fallback_warnings_mode(
1290*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1291*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
1292*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1293*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
1294*da0073e9SAndroid Build Coastguard Worker       PyBool_Check(arg),
1295*da0073e9SAndroid Build Coastguard Worker       "enabled must be a bool, "
1296*da0073e9SAndroid Build Coastguard Worker       "but got ",
1297*da0073e9SAndroid Build Coastguard Worker       THPUtils_typename(arg));
1298*da0073e9SAndroid Build Coastguard Worker   at::globalContext().setDisplayVmapFallbackWarnings(arg == Py_True);
1299*da0073e9SAndroid Build Coastguard Worker   Py_RETURN_NONE;
1300*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1301*da0073e9SAndroid Build Coastguard Worker }
1302*da0073e9SAndroid Build Coastguard Worker 
THPModule_are_vmap_fallback_warnings_enabled(PyObject * _unused,PyObject * arg)1303*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
1304*da0073e9SAndroid Build Coastguard Worker     PyObject* _unused,
1305*da0073e9SAndroid Build Coastguard Worker     PyObject* arg) {
1306*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1307*da0073e9SAndroid Build Coastguard Worker   if (at::globalContext().areVmapFallbackWarningsEnabled()) {
1308*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_TRUE;
1309*da0073e9SAndroid Build Coastguard Worker   } else {
1310*da0073e9SAndroid Build Coastguard Worker     Py_RETURN_FALSE;
1311*da0073e9SAndroid Build Coastguard Worker   }
1312*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
1313*da0073e9SAndroid Build Coastguard Worker }
1314*da0073e9SAndroid Build Coastguard Worker 
// Python method table for this extension module. Each entry is
// {python_name, C implementation, calling-convention flags, docstring}; every
// docstring slot here is nullptr. The final all-null entry is the sentinel
// CPython uses to detect the end of a PyMethodDef array and must stay last.
static PyMethodDef TorchMethods[] = { // NOLINT
    {"_initExtension", THPModule_initExtension, METH_O, nullptr},
    {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
    {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr},
    {"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr},
    {"_init_names", THPModule_initNames, METH_O, nullptr},
    {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr},
    {"_set_default_tensor_type",
     THPModule_setDefaultTensorType,
     METH_O,
     nullptr},
    {"_set_default_dtype", THPModule_setDefaultDtype, METH_O, nullptr},
    {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr},
    {"_abort", THPModule_abort, METH_NOARGS, nullptr},
    {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr},
    {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr},
    {"_crash_if_vptr_ubsan", THPModule_crashIfvptrUBSAN, METH_NOARGS, nullptr},
    {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr},
    {"_crash_if_debug_asserts_fail",
     THPModule_crashIfDebugAssertsFail,
     METH_O,
     nullptr},
    {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr},
    {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr},
    {"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr},
    {"_get_cpu_capability", THPModule_getCpuCapability, METH_NOARGS, nullptr},
    {"_set_backcompat_broadcast_warn",
     THPModule_setBackcompatBroadcastWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_broadcast_warn",
     THPModule_getBackcompatBroadcastWarn,
     METH_NOARGS,
     nullptr},
    {"_set_backcompat_keepdim_warn",
     THPModule_setBackcompatKeepdimWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_keepdim_warn",
     THPModule_getBackcompatKeepdimWarn,
     METH_NOARGS,
     nullptr},
    {"get_num_threads", THPModule_getNumThreads, METH_NOARGS, nullptr},
    {"set_num_threads", THPModule_setNumThreads, METH_O, nullptr},
    {"get_num_interop_threads",
     THPModule_getNumInteropThreads,
     METH_NOARGS,
     nullptr},
    {"set_num_interop_threads",
     THPModule_setNumInteropThreads,
     METH_O,
     nullptr},
    {"_get_flash_sdp_enabled",
     THPModule_userEnabledFlashSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
    {"_get_mem_efficient_sdp_enabled",
     userEnabledMemEfficientSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_mem_efficient",
     THPModule_setSDPUseMemEfficient,
     METH_O,
     nullptr},
    {"_get_math_sdp_enabled",
     THPModule_userEnabledMathSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
    {"_get_math_sdp_allow_fp16_bf16_reduction",
     THPModule_allowFP16BF16ReductionMathSDP,
     METH_NOARGS,
     nullptr},
    {"_set_math_sdp_allow_fp16_bf16_reduction",
     THPModule_setAllowFP16BF16ReductionMathSDP,
     METH_O,
     nullptr},
    {"_get_overrideable_sdp_enabled",
     THPModule_userEnabledOverrideableSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_overrideable",
     THPModule_setSDPUseOverrideable,
     METH_O,
     nullptr},
    {"_get_cudnn_sdp_enabled",
     THPModule_userEnabledCuDNNSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_cudnn", THPModule_setSDPUseCuDNN, METH_O, nullptr},
    {"_get_cudnn_enabled", THPModule_userEnabledCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr},
    {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr},
    {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr},
    {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr},
    {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
    {"_get_cudnn_deterministic",
     THPModule_deterministicCuDNN,
     METH_NOARGS,
     nullptr},
    {"_set_cudnn_deterministic",
     THPModule_setDeterministicCuDNN,
     METH_O,
     nullptr},
    {"_get_mkldnn_deterministic",
     THPModule_deterministicMkldnn,
     METH_NOARGS,
     nullptr},
    {"_set_mkldnn_deterministic",
     THPModule_setDeterministicMkldnn,
     METH_O,
     nullptr},
    {"_get_deterministic_algorithms",
     THPModule_deterministicAlgorithms,
     METH_NOARGS,
     nullptr},
    {"_get_deterministic_algorithms_warn_only",
     THPModule_deterministicAlgorithmsWarnOnly,
     METH_NOARGS,
     nullptr},
    // Accepts keyword arguments, so its (args, kwargs) implementation is cast
    // to the PyCFunction slot type via castPyCFunctionWithKeywords.
    {"_set_deterministic_algorithms",
     castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_get_deterministic_fill_uninitialized_memory",
     THPModule_deterministicFillUninitializedMemory,
     METH_NOARGS,
     nullptr},
    {"_set_deterministic_fill_uninitialized_memory",
     THPModule_setDeterministicFillUninitializedMemory,
     METH_O,
     nullptr},
    {"_get_nnpack_enabled", THPModule_userEnabledNNPACK, METH_NOARGS, nullptr},
    {"_set_nnpack_enabled", THPModule_setUserEnabledNNPACK, METH_O, nullptr},
    {"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
    {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
    {"_warn", THPModule_warn, METH_NOARGS, nullptr},
    {"_warn_deprecation", THPModule_warnDeprecation, METH_NOARGS, nullptr},
    {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr},
    {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr},
    {"_get_float32_matmul_precision",
     THPModule_float32MatmulPrecision,
     METH_NOARGS,
     nullptr},
    {"_set_float32_matmul_precision",
     THPModule_setFloat32MatmulPrecision,
     METH_O,
     nullptr},
    {"_get_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_allowFP16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_setAllowFP16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_get_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_allowBF16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_setAllowBF16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_get_cpu_allow_fp16_reduced_precision_reduction",
     THPModule_allowFP16ReductionCPU,
     METH_NOARGS,
     nullptr},
    {"_set_cpu_allow_fp16_reduced_precision_reduction",
     THPModule_setAllowFP16ReductionCPU,
     METH_O,
     nullptr},
    // Legacy vmap-mode helpers; the NOARGS entries here ignore their second
    // parameter.
    {"_vmapmode_increment_nesting",
     THPModule_vmapmode_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"_vmapmode_decrement_nesting",
     THPModule_vmapmode_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"_debug_only_display_vmap_fallback_warnings",
     THPModule_set_display_vmap_fallback_warnings_mode,
     METH_O,
     nullptr},
    {"_debug_only_are_vmap_fallback_warnings_enabled",
     THPModule_are_vmap_fallback_warnings_enabled,
     METH_NOARGS,
     nullptr},
    {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr},
    {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr},
    {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr},
    {"_rename_privateuse1_backend",
     THModule_rename_privateuse1_backend,
     METH_O,
     nullptr},
    {"_get_privateuse1_backend_name",
     THModule_get_privateuse1_backend_name,
     METH_NOARGS,
     nullptr},
    {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr},
    {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr},
    {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr},
    {"_get_qengine", THPModule_qEngine, METH_NOARGS, nullptr},
    {"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
    {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
    {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
    {"_set_check_sparse_tensor_invariants",
     THPModule_setCheckSparseTensorInvariants,
     METH_O,
     nullptr},
    {"_check_sparse_tensor_invariants",
     THPModule_checkSparseTensorInvariants,
     METH_NOARGS,
     nullptr},
    // Autograd engine introspection; these only succeed while a backward pass
    // is running (see the TORCH_CHECKs in their implementations).
    {"_will_engine_execute_node",
     THPModule_willEngineExecuteNode,
     METH_O,
     nullptr},
    {"_current_graph_task_execution_order",
     THPModule_getCurrentGraphTaskExecutionOrder,
     METH_NOARGS,
     nullptr},
    {"_current_graph_task_id",
     THPModule_getCurrentGraphTaskId,
     METH_NOARGS,
     nullptr},
    {"_current_autograd_node", THPModule_getCurrentNode, METH_NOARGS, nullptr},
    {"_set_default_mobile_cpu_allocator",
     THPModule_setDefaultMobileCPUAllocator,
     METH_NOARGS,
     nullptr},
    {"_unset_default_mobile_cpu_allocator",
     THPModule_unsetDefaultMobileCPUAllocator,
     METH_NOARGS,
     nullptr},
    {"_is_torch_function_enabled",
     THPModule_isEnabledTorchFunction,
     METH_NOARGS,
     nullptr},
    {"_is_torch_function_all_disabled",
     THPModule_isAllDisabledTorchFunction,
     METH_NOARGS,
     nullptr},
    {"_disabled_torch_function_impl",
     THPModule_disable_torch_function,
     METH_VARARGS,
     nullptr},
    {"_disabled_torch_dispatch_impl",
     THPModule_disable_torch_dispatch,
     METH_VARARGS,
     nullptr},
    {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
    {"_has_torch_function_unary",
     THPModule_has_torch_function_unary,
     METH_O,
     nullptr},
    // METH_FASTCALL implementations do not have the PyCFunction signature.
    // The cast goes through `void (*)()`, presumably to avoid
    // function-pointer cast warnings (e.g. -Wcast-function-type).
    {"_has_torch_function_variadic",
     (PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
     METH_FASTCALL,
     nullptr},
    // Sentinel: marks the end of the table for CPython.
    {nullptr, nullptr, 0, nullptr}};
1579*da0073e9SAndroid Build Coastguard Worker 
// Forward declarations of THCP* type-initialization hooks defined in other
// translation units. They are declared unconditionally here; NOTE(review):
// presumably these are the CUDA stream/event/graph/mempool bindings -- confirm
// against their definitions.
void THCPStream_init(PyObject* module);
void THCPEvent_init(PyObject* module);
void THCPGraph_init(PyObject* module);
void THCPMemPool_init(PyObject* module);

// Declarations that only exist in builds compiled with USE_CUDA.
#ifdef USE_CUDA
PyMethodDef* THCPModule_methods();
namespace torch::cuda {
void initModule(PyObject* module);
} // namespace torch::cuda
#endif

// Declarations that only exist in builds compiled with USE_XPU.
#ifdef USE_XPU
PyMethodDef* THXPModule_methods();
void THXPStream_init(PyObject* module);
void THXPEvent_init(PyObject* module);
namespace torch::xpu {
void initModule(PyObject* module);
} // namespace torch::xpu
#endif

// Declarations that only exist in builds compiled with USE_ITT.
#ifdef USE_ITT
namespace torch::profiler {
void initIttBindings(PyObject* module);
} // namespace torch::profiler
#endif

// File-local collection of PyMethodDef entries. NOTE(review): it is filled
// and consumed elsewhere in this file (not visible here) -- presumably by
// combining TorchMethods with backend method tables; confirm at the use site.
static std::vector<PyMethodDef> methods;
1608*da0073e9SAndroid Build Coastguard Worker 
1609*da0073e9SAndroid Build Coastguard Worker // In Python we can't use the trick of C10_LOG_API_USAGE_ONCE
1610*da0073e9SAndroid Build Coastguard Worker // Guaranteed to be invoked from Python under GIL, no locking on map needed
LogAPIUsageOnceFromPython(const std::string & event)1611*da0073e9SAndroid Build Coastguard Worker static void LogAPIUsageOnceFromPython(const std::string& event) {
1612*da0073e9SAndroid Build Coastguard Worker   static std::unordered_set<std::string> seen;
1613*da0073e9SAndroid Build Coastguard Worker   if (!seen.count(event)) {
1614*da0073e9SAndroid Build Coastguard Worker     seen.insert(event);
1615*da0073e9SAndroid Build Coastguard Worker     c10::LogAPIUsage(event);
1616*da0073e9SAndroid Build Coastguard Worker   }
1617*da0073e9SAndroid Build Coastguard Worker }
1618*da0073e9SAndroid Build Coastguard Worker 
LogAPIUsageMetadataFromPython(const std::string & event,const std::map<std::string,std::string> & metadata_map)1619*da0073e9SAndroid Build Coastguard Worker static void LogAPIUsageMetadataFromPython(
1620*da0073e9SAndroid Build Coastguard Worker     const std::string& event,
1621*da0073e9SAndroid Build Coastguard Worker     const std::map<std::string, std::string>& metadata_map) {
1622*da0073e9SAndroid Build Coastguard Worker   c10::LogAPIUsageMetadata(event, metadata_map);
1623*da0073e9SAndroid Build Coastguard Worker }
1624*da0073e9SAndroid Build Coastguard Worker 
1625*da0073e9SAndroid Build Coastguard Worker // Weak reference to tensor, used to test a tensor isn't leaked
1626*da0073e9SAndroid Build Coastguard Worker class WeakTensorRef {
1627*da0073e9SAndroid Build Coastguard Worker   c10::weak_intrusive_ptr<c10::TensorImpl> weakref_;
1628*da0073e9SAndroid Build Coastguard Worker 
1629*da0073e9SAndroid Build Coastguard Worker  public:
WeakTensorRef(const at::Tensor & t)1630*da0073e9SAndroid Build Coastguard Worker   WeakTensorRef(const at::Tensor& t) : weakref_(t.getIntrusivePtr()) {}
1631*da0073e9SAndroid Build Coastguard Worker 
expired()1632*da0073e9SAndroid Build Coastguard Worker   bool expired() {
1633*da0073e9SAndroid Build Coastguard Worker     return weakref_.expired();
1634*da0073e9SAndroid Build Coastguard Worker   }
1635*da0073e9SAndroid Build Coastguard Worker };
1636*da0073e9SAndroid Build Coastguard Worker 
1637*da0073e9SAndroid Build Coastguard Worker extern "C" C10_EXPORT PyObject* initModule();
1638*da0073e9SAndroid Build Coastguard Worker // separate decl and defn for msvc error C2491
initModule()1639*da0073e9SAndroid Build Coastguard Worker PyObject* initModule() {
1640*da0073e9SAndroid Build Coastguard Worker   HANDLE_TH_ERRORS
1641*da0073e9SAndroid Build Coastguard Worker 
1642*da0073e9SAndroid Build Coastguard Worker   c10::initLogging();
1643*da0073e9SAndroid Build Coastguard Worker   c10::set_terminate_handler();
1644*da0073e9SAndroid Build Coastguard Worker   at::internal::lazy_init_num_threads();
1645*da0073e9SAndroid Build Coastguard Worker 
1646*da0073e9SAndroid Build Coastguard Worker   C10_LOG_API_USAGE_ONCE("torch.python.import");
1647*da0073e9SAndroid Build Coastguard Worker 
1648*da0073e9SAndroid Build Coastguard Worker #define ASSERT_TRUE(cmd) \
1649*da0073e9SAndroid Build Coastguard Worker   if (!(cmd))            \
1650*da0073e9SAndroid Build Coastguard Worker   return nullptr
1651*da0073e9SAndroid Build Coastguard Worker 
1652*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(methods, TorchMethods);
1653*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
1654*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
1655*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
1656*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
1657*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
1658*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(methods, THCPModule_methods());
1659*da0073e9SAndroid Build Coastguard Worker #endif
1660*da0073e9SAndroid Build Coastguard Worker #ifdef USE_XPU
1661*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(methods, THXPModule_methods());
1662*da0073e9SAndroid Build Coastguard Worker #endif
1663*da0073e9SAndroid Build Coastguard Worker #if defined(USE_DISTRIBUTED) && defined(USE_C10D)
1664*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(
1665*da0073e9SAndroid Build Coastguard Worker       methods, torch::distributed::c10d::python_functions());
1666*da0073e9SAndroid Build Coastguard Worker #ifndef _WIN32
1667*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(
1668*da0073e9SAndroid Build Coastguard Worker       methods, torch::distributed::rpc::python_functions());
1669*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(
1670*da0073e9SAndroid Build Coastguard Worker       methods, torch::distributed::autograd::python_functions());
1671*da0073e9SAndroid Build Coastguard Worker   THPUtils_addPyMethodDefs(
1672*da0073e9SAndroid Build Coastguard Worker       methods, torch::distributed::rpc::testing::python_functions());
1673*da0073e9SAndroid Build Coastguard Worker #endif
1674*da0073e9SAndroid Build Coastguard Worker #endif
1675*da0073e9SAndroid Build Coastguard Worker 
1676*da0073e9SAndroid Build Coastguard Worker   static struct PyModuleDef torchmodule = {
1677*da0073e9SAndroid Build Coastguard Worker       PyModuleDef_HEAD_INIT, "torch._C", nullptr, -1, methods.data()};
1678*da0073e9SAndroid Build Coastguard Worker   module = PyModule_Create(&torchmodule);
1679*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(module);
1680*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(THPGenerator_init(module));
1681*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(THPException_init(module));
1682*da0073e9SAndroid Build Coastguard Worker   THPSize_init(module);
1683*da0073e9SAndroid Build Coastguard Worker   THPDtype_init(module);
1684*da0073e9SAndroid Build Coastguard Worker   THPDTypeInfo_init(module);
1685*da0073e9SAndroid Build Coastguard Worker   THPLayout_init(module);
1686*da0073e9SAndroid Build Coastguard Worker   THPMemoryFormat_init(module);
1687*da0073e9SAndroid Build Coastguard Worker   THPQScheme_init(module);
1688*da0073e9SAndroid Build Coastguard Worker   THPDevice_init(module);
1689*da0073e9SAndroid Build Coastguard Worker   THPStream_init(module);
1690*da0073e9SAndroid Build Coastguard Worker   THPEvent_init(module);
1691*da0073e9SAndroid Build Coastguard Worker   NodeBase_init(module);
1692*da0073e9SAndroid Build Coastguard Worker   NodeIter_init(module);
1693*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(THPVariable_initModule(module));
1694*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(THPFunction_initModule(module));
1695*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(THPEngine_initModule(module));
1696*da0073e9SAndroid Build Coastguard Worker   // NOTE: We need to be able to access OperatorExportTypes from ONNX for use in
1697*da0073e9SAndroid Build Coastguard Worker   // the export side of JIT, so this ONNX init needs to appear before the JIT
1698*da0073e9SAndroid Build Coastguard Worker   // init.
1699*da0073e9SAndroid Build Coastguard Worker   torch::onnx::initONNXBindings(module);
1700*da0073e9SAndroid Build Coastguard Worker   torch::autograd::initEnumTag(module);
1701*da0073e9SAndroid Build Coastguard Worker   torch::jit::initJITBindings(module);
1702*da0073e9SAndroid Build Coastguard Worker   torch::monitor::initMonitorBindings(module);
1703*da0073e9SAndroid Build Coastguard Worker   torch::impl::dispatch::initDispatchBindings(module);
1704*da0073e9SAndroid Build Coastguard Worker   torch::dynamo::initDynamoBindings(module);
1705*da0073e9SAndroid Build Coastguard Worker   torch::functorch::impl::initFuncTorchBindings(module);
1706*da0073e9SAndroid Build Coastguard Worker   torch::throughput_benchmark::initThroughputBenchmarkBindings(module);
1707*da0073e9SAndroid Build Coastguard Worker   torch::autograd::initReturnTypes(module);
1708*da0073e9SAndroid Build Coastguard Worker   torch::autograd::initNNFunctions(module);
1709*da0073e9SAndroid Build Coastguard Worker   torch::autograd::initFFTFunctions(module);
1710*da0073e9SAndroid Build Coastguard Worker   torch::autograd::initLinalgFunctions(module);
1711*da0073e9SAndroid Build Coastguard Worker   torch::autograd::initNestedFunctions(module);
1712*da0073e9SAndroid Build Coastguard Worker   torch::autograd::initSparseFunctions(module);
1713*da0073e9SAndroid Build Coastguard Worker   torch::autograd::initSpecialFunctions(module);
1714*da0073e9SAndroid Build Coastguard Worker   torch::autograd::init_legacy_variable(module);
1715*da0073e9SAndroid Build Coastguard Worker   torch::profiler::initPythonBindings(module);
1716*da0073e9SAndroid Build Coastguard Worker   torch::python::init_bindings(module);
1717*da0073e9SAndroid Build Coastguard Worker   torch::lazy::initLazyBindings(module);
1718*da0073e9SAndroid Build Coastguard Worker   torch::inductor::initAOTIRunnerBindings(module);
1719*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ITT
1720*da0073e9SAndroid Build Coastguard Worker   torch::profiler::initIttBindings(module);
1721*da0073e9SAndroid Build Coastguard Worker #endif
1722*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
1723*da0073e9SAndroid Build Coastguard Worker   torch::cuda::initModule(module);
1724*da0073e9SAndroid Build Coastguard Worker #endif
1725*da0073e9SAndroid Build Coastguard Worker #ifdef USE_XPU
1726*da0073e9SAndroid Build Coastguard Worker   torch::xpu::initModule(module);
1727*da0073e9SAndroid Build Coastguard Worker #endif
1728*da0073e9SAndroid Build Coastguard Worker   torch::mtia::initModule(module);
1729*da0073e9SAndroid Build Coastguard Worker   torch::cpu::initModule(module);
1730*da0073e9SAndroid Build Coastguard Worker   torch::instruction_counter::initModule(module);
1731*da0073e9SAndroid Build Coastguard Worker   torch::initVerboseBindings(module);
1732*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(THPStorage_init(module));
1733*da0073e9SAndroid Build Coastguard Worker 
1734*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
1735*da0073e9SAndroid Build Coastguard Worker   // This will only initialise base classes and attach them to library namespace
1736*da0073e9SAndroid Build Coastguard Worker   // They won't be ready for real usage until importing cuda module, that will
1737*da0073e9SAndroid Build Coastguard Worker   // complete the process (but it defines Python classes before calling back
1738*da0073e9SAndroid Build Coastguard Worker   // into C, so these lines have to execute first)..
1739*da0073e9SAndroid Build Coastguard Worker   THCPStream_init(module);
1740*da0073e9SAndroid Build Coastguard Worker   THCPEvent_init(module);
1741*da0073e9SAndroid Build Coastguard Worker   THCPGraph_init(module);
1742*da0073e9SAndroid Build Coastguard Worker   THCPMemPool_init(module);
1743*da0073e9SAndroid Build Coastguard Worker #endif
1744*da0073e9SAndroid Build Coastguard Worker 
1745*da0073e9SAndroid Build Coastguard Worker #ifdef USE_XPU
1746*da0073e9SAndroid Build Coastguard Worker   THXPStream_init(module);
1747*da0073e9SAndroid Build Coastguard Worker   THXPEvent_init(module);
1748*da0073e9SAndroid Build Coastguard Worker #endif
1749*da0073e9SAndroid Build Coastguard Worker 
1750*da0073e9SAndroid Build Coastguard Worker   auto set_module_attr =
1751*da0073e9SAndroid Build Coastguard Worker       [&](const char* name, PyObject* v, bool incref = true) {
1752*da0073e9SAndroid Build Coastguard Worker         // PyModule_AddObject steals reference
1753*da0073e9SAndroid Build Coastguard Worker         if (incref) {
1754*da0073e9SAndroid Build Coastguard Worker           Py_INCREF(v);
1755*da0073e9SAndroid Build Coastguard Worker         }
1756*da0073e9SAndroid Build Coastguard Worker 
1757*da0073e9SAndroid Build Coastguard Worker         int ret = PyModule_AddObject(module, name, v);
1758*da0073e9SAndroid Build Coastguard Worker         if (ret != 0) {
1759*da0073e9SAndroid Build Coastguard Worker           Py_DECREF(v);
1760*da0073e9SAndroid Build Coastguard Worker         }
1761*da0073e9SAndroid Build Coastguard Worker 
1762*da0073e9SAndroid Build Coastguard Worker         return ret == 0;
1763*da0073e9SAndroid Build Coastguard Worker       };
1764*da0073e9SAndroid Build Coastguard Worker 
1765*da0073e9SAndroid Build Coastguard Worker #if defined(USE_CUDNN) || defined(USE_ROCM)
1766*da0073e9SAndroid Build Coastguard Worker   PyObject* has_cudnn = Py_True;
1767*da0073e9SAndroid Build Coastguard Worker #else
1768*da0073e9SAndroid Build Coastguard Worker   PyObject* has_cudnn = Py_False;
1769*da0073e9SAndroid Build Coastguard Worker #endif
1770*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn));
1771*da0073e9SAndroid Build Coastguard Worker 
1772*da0073e9SAndroid Build Coastguard Worker #if defined(USE_CUSPARSELT)
1773*da0073e9SAndroid Build Coastguard Worker   PyObject* has_cusparselt = Py_True;
1774*da0073e9SAndroid Build Coastguard Worker #else
1775*da0073e9SAndroid Build Coastguard Worker   PyObject* has_cusparselt = Py_False;
1776*da0073e9SAndroid Build Coastguard Worker #endif
1777*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_has_cusparselt", has_cusparselt));
1778*da0073e9SAndroid Build Coastguard Worker 
1779*da0073e9SAndroid Build Coastguard Worker #if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
1780*da0073e9SAndroid Build Coastguard Worker   PyObject* has_spectral = Py_True;
1781*da0073e9SAndroid Build Coastguard Worker #else
1782*da0073e9SAndroid Build Coastguard Worker   PyObject* has_spectral = Py_False;
1783*da0073e9SAndroid Build Coastguard Worker #endif
1784*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("has_spectral", has_spectral));
1785*da0073e9SAndroid Build Coastguard Worker 
1786*da0073e9SAndroid Build Coastguard Worker   // force ATen to initialize because it handles
1787*da0073e9SAndroid Build Coastguard Worker   // setting up TH Errors so that they throw C++ exceptions
1788*da0073e9SAndroid Build Coastguard Worker   at::init();
1789*da0073e9SAndroid Build Coastguard Worker 
1790*da0073e9SAndroid Build Coastguard Worker   // Automatically translate errors thrown from pybind11 functions
1791*da0073e9SAndroid Build Coastguard Worker   py::register_exception_translator([](std::exception_ptr e) { // NOLINT
1792*da0073e9SAndroid Build Coastguard Worker     try {
1793*da0073e9SAndroid Build Coastguard Worker       if (e) {
1794*da0073e9SAndroid Build Coastguard Worker         std::rethrow_exception(e);
1795*da0073e9SAndroid Build Coastguard Worker       }
1796*da0073e9SAndroid Build Coastguard Worker     }
1797*da0073e9SAndroid Build Coastguard Worker     CATCH_TH_ERRORS()
1798*da0073e9SAndroid Build Coastguard Worker   });
1799*da0073e9SAndroid Build Coastguard Worker 
1800*da0073e9SAndroid Build Coastguard Worker   auto py_module = py::reinterpret_borrow<py::module>(module);
1801*da0073e9SAndroid Build Coastguard Worker   py_module.def("_demangle", &c10::demangle);
1802*da0073e9SAndroid Build Coastguard Worker   py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython);
1803*da0073e9SAndroid Build Coastguard Worker   py_module.def("_log_api_usage_metadata", &LogAPIUsageMetadataFromPython);
1804*da0073e9SAndroid Build Coastguard Worker 
1805*da0073e9SAndroid Build Coastguard Worker   py_module.def("vitals_enabled", &at::vitals::torchVitalEnabled);
1806*da0073e9SAndroid Build Coastguard Worker   py_module.def(
1807*da0073e9SAndroid Build Coastguard Worker       "set_vital",
1808*da0073e9SAndroid Build Coastguard Worker       [](const std::string& vital,
1809*da0073e9SAndroid Build Coastguard Worker          const std::string& attr,
1810*da0073e9SAndroid Build Coastguard Worker          const std::string& value) {
1811*da0073e9SAndroid Build Coastguard Worker         return at::vitals::VitalsAPI.setVital(vital, attr, value);
1812*da0073e9SAndroid Build Coastguard Worker       });
1813*da0073e9SAndroid Build Coastguard Worker   py_module.def(
1814*da0073e9SAndroid Build Coastguard Worker       "read_vitals", []() { return at::vitals::VitalsAPI.readVitals(); });
1815*da0073e9SAndroid Build Coastguard Worker 
1816*da0073e9SAndroid Build Coastguard Worker   py_module.def(
1817*da0073e9SAndroid Build Coastguard Worker       "init_num_threads",
1818*da0073e9SAndroid Build Coastguard Worker       torch::wrap_pybind_function(at::init_num_threads),
1819*da0073e9SAndroid Build Coastguard Worker       R"(
1820*da0073e9SAndroid Build Coastguard Worker init_num_threads()
1821*da0073e9SAndroid Build Coastguard Worker 
1822*da0073e9SAndroid Build Coastguard Worker Initializes the number of parallel threads used on the current thread.
1823*da0073e9SAndroid Build Coastguard Worker 
1824*da0073e9SAndroid Build Coastguard Worker Call this whenever a new thread is created in order to propagate values from
1825*da0073e9SAndroid Build Coastguard Worker :func:`torch.set_num_threads` onto the new thread.
1826*da0073e9SAndroid Build Coastguard Worker )");
1827*da0073e9SAndroid Build Coastguard Worker 
1828*da0073e9SAndroid Build Coastguard Worker   py_module.def("_set_cached_tensors_enabled", [](bool enabled) {
1829*da0073e9SAndroid Build Coastguard Worker     at::caching::set_cached_tensors_enabled(enabled);
1830*da0073e9SAndroid Build Coastguard Worker   });
1831*da0073e9SAndroid Build Coastguard Worker 
1832*da0073e9SAndroid Build Coastguard Worker   py_module.def("_add_cached_tensor", [](const at::Tensor& t) {
1833*da0073e9SAndroid Build Coastguard Worker     at::caching::add_cached_tensor(t);
1834*da0073e9SAndroid Build Coastguard Worker   });
1835*da0073e9SAndroid Build Coastguard Worker 
1836*da0073e9SAndroid Build Coastguard Worker   py_module.def("_remove_cached_tensor", [](const at::Tensor& t) {
1837*da0073e9SAndroid Build Coastguard Worker     at::caching::remove_cached_tensor(t);
1838*da0073e9SAndroid Build Coastguard Worker   });
1839*da0073e9SAndroid Build Coastguard Worker 
1840*da0073e9SAndroid Build Coastguard Worker   py_module.def("_is_cached_tensor", [](const at::Tensor& t) {
1841*da0073e9SAndroid Build Coastguard Worker     return at::caching::is_cached_tensor(t);
1842*da0073e9SAndroid Build Coastguard Worker   });
1843*da0073e9SAndroid Build Coastguard Worker 
1844*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
1845*da0073e9SAndroid Build Coastguard Worker       set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
1846*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
1847*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
1848*da0073e9SAndroid Build Coastguard Worker       set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
1849*da0073e9SAndroid Build Coastguard Worker 
1850*da0073e9SAndroid Build Coastguard Worker   py_module.def("_valgrind_supported_platform", []() {
1851*da0073e9SAndroid Build Coastguard Worker #if defined(USE_VALGRIND)
1852*da0073e9SAndroid Build Coastguard Worker     return true;
1853*da0073e9SAndroid Build Coastguard Worker #else
1854*da0073e9SAndroid Build Coastguard Worker       return false;
1855*da0073e9SAndroid Build Coastguard Worker #endif
1856*da0073e9SAndroid Build Coastguard Worker   });
1857*da0073e9SAndroid Build Coastguard Worker 
1858*da0073e9SAndroid Build Coastguard Worker   py_module.def("_valgrind_toggle", []() {
1859*da0073e9SAndroid Build Coastguard Worker #if defined(USE_VALGRIND)
1860*da0073e9SAndroid Build Coastguard Worker     CALLGRIND_TOGGLE_COLLECT;
1861*da0073e9SAndroid Build Coastguard Worker #else
1862*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(false, "Valgrind is not supported.");
1863*da0073e9SAndroid Build Coastguard Worker #endif
1864*da0073e9SAndroid Build Coastguard Worker   });
1865*da0073e9SAndroid Build Coastguard Worker 
1866*da0073e9SAndroid Build Coastguard Worker   py_module.def("_valgrind_toggle_and_dump_stats", []() {
1867*da0073e9SAndroid Build Coastguard Worker #if defined(USE_VALGRIND)
1868*da0073e9SAndroid Build Coastguard Worker     // NB: If we don't toggle collect around dump stats, callgrind_annotate
1869*da0073e9SAndroid Build Coastguard Worker     //     won't process the results correctly. Specifically,
1870*da0073e9SAndroid Build Coastguard Worker     //     `callgrind_annotate --inclusive=no` will be almost completely empty.
1871*da0073e9SAndroid Build Coastguard Worker     CALLGRIND_TOGGLE_COLLECT;
1872*da0073e9SAndroid Build Coastguard Worker     CALLGRIND_DUMP_STATS;
1873*da0073e9SAndroid Build Coastguard Worker #else
1874*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(false, "Valgrind is not supported.");
1875*da0073e9SAndroid Build Coastguard Worker #endif
1876*da0073e9SAndroid Build Coastguard Worker   });
1877*da0073e9SAndroid Build Coastguard Worker 
1878*da0073e9SAndroid Build Coastguard Worker   py::class_<WeakTensorRef>(py_module, "_WeakTensorRef")
1879*da0073e9SAndroid Build Coastguard Worker       .def(py::init([](py::object tensor) {
1880*da0073e9SAndroid Build Coastguard Worker         return WeakTensorRef(THPVariable_Unpack(tensor.ptr()));
1881*da0073e9SAndroid Build Coastguard Worker       }))
1882*da0073e9SAndroid Build Coastguard Worker       .def("expired", &WeakTensorRef::expired);
1883*da0073e9SAndroid Build Coastguard Worker 
1884*da0073e9SAndroid Build Coastguard Worker   py::enum_<at::native::ConvBackend>(py_module, "_ConvBackend")
1885*da0073e9SAndroid Build Coastguard Worker       .value("CudaDepthwise2d", at::native::ConvBackend::CudaDepthwise2d)
1886*da0073e9SAndroid Build Coastguard Worker       .value("CudaDepthwise3d", at::native::ConvBackend::CudaDepthwise3d)
1887*da0073e9SAndroid Build Coastguard Worker       .value("Cudnn", at::native::ConvBackend::Cudnn)
1888*da0073e9SAndroid Build Coastguard Worker       .value("CudnnTranspose", at::native::ConvBackend::CudnnTranspose)
1889*da0073e9SAndroid Build Coastguard Worker       .value("Empty", at::native::ConvBackend::Empty)
1890*da0073e9SAndroid Build Coastguard Worker       .value("Miopen", at::native::ConvBackend::Miopen)
1891*da0073e9SAndroid Build Coastguard Worker       .value("MiopenDepthwise", at::native::ConvBackend::MiopenDepthwise)
1892*da0073e9SAndroid Build Coastguard Worker       .value("MiopenTranspose", at::native::ConvBackend::MiopenTranspose)
1893*da0073e9SAndroid Build Coastguard Worker       .value("Mkldnn", at::native::ConvBackend::Mkldnn)
1894*da0073e9SAndroid Build Coastguard Worker       .value("MkldnnEmpty", at::native::ConvBackend::MkldnnEmpty)
1895*da0073e9SAndroid Build Coastguard Worker       .value("NnpackSpatial", at::native::ConvBackend::NnpackSpatial)
1896*da0073e9SAndroid Build Coastguard Worker       .value("Overrideable", at::native::ConvBackend::Overrideable)
1897*da0073e9SAndroid Build Coastguard Worker       .value("Slow2d", at::native::ConvBackend::Slow2d)
1898*da0073e9SAndroid Build Coastguard Worker       .value("Slow3d", at::native::ConvBackend::Slow3d)
1899*da0073e9SAndroid Build Coastguard Worker       .value("SlowDilated2d", at::native::ConvBackend::SlowDilated2d)
1900*da0073e9SAndroid Build Coastguard Worker       .value("SlowDilated3d", at::native::ConvBackend::SlowDilated3d)
1901*da0073e9SAndroid Build Coastguard Worker       .value("SlowTranspose2d", at::native::ConvBackend::SlowTranspose2d)
1902*da0073e9SAndroid Build Coastguard Worker       .value("SlowTranspose3d", at::native::ConvBackend::SlowTranspose3d)
1903*da0073e9SAndroid Build Coastguard Worker       .value(
1904*da0073e9SAndroid Build Coastguard Worker           "Winograd3x3Depthwise", at::native::ConvBackend::Winograd3x3Depthwise)
1905*da0073e9SAndroid Build Coastguard Worker       .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d)
1906*da0073e9SAndroid Build Coastguard Worker       .value("Mps", at::native::ConvBackend::Mps)
1907*da0073e9SAndroid Build Coastguard Worker       .value("MpsTranspose,", at::native::ConvBackend::MpsTranspose);
1908*da0073e9SAndroid Build Coastguard Worker 
1909*da0073e9SAndroid Build Coastguard Worker   py_module.def(
1910*da0073e9SAndroid Build Coastguard Worker       "_select_conv_backend",
1911*da0073e9SAndroid Build Coastguard Worker       [](const at::Tensor& input,
1912*da0073e9SAndroid Build Coastguard Worker          const at::Tensor& weight,
1913*da0073e9SAndroid Build Coastguard Worker          const std::optional<at::Tensor>& bias_opt,
1914*da0073e9SAndroid Build Coastguard Worker          at::SymIntArrayRef stride_,
1915*da0073e9SAndroid Build Coastguard Worker          at::SymIntArrayRef padding_,
1916*da0073e9SAndroid Build Coastguard Worker          at::SymIntArrayRef dilation_,
1917*da0073e9SAndroid Build Coastguard Worker          bool transposed_,
1918*da0073e9SAndroid Build Coastguard Worker          at::SymIntArrayRef output_padding_,
1919*da0073e9SAndroid Build Coastguard Worker          c10::SymInt groups_) {
1920*da0073e9SAndroid Build Coastguard Worker         return at::native::select_conv_backend(
1921*da0073e9SAndroid Build Coastguard Worker             input,
1922*da0073e9SAndroid Build Coastguard Worker             weight,
1923*da0073e9SAndroid Build Coastguard Worker             bias_opt,
1924*da0073e9SAndroid Build Coastguard Worker             stride_,
1925*da0073e9SAndroid Build Coastguard Worker             padding_,
1926*da0073e9SAndroid Build Coastguard Worker             dilation_,
1927*da0073e9SAndroid Build Coastguard Worker             transposed_,
1928*da0073e9SAndroid Build Coastguard Worker             output_padding_,
1929*da0073e9SAndroid Build Coastguard Worker             std::move(groups_),
1930*da0073e9SAndroid Build Coastguard Worker             std::nullopt);
1931*da0073e9SAndroid Build Coastguard Worker       },
1932*da0073e9SAndroid Build Coastguard Worker       py::arg("input"),
1933*da0073e9SAndroid Build Coastguard Worker       py::arg("weight"),
1934*da0073e9SAndroid Build Coastguard Worker       py::arg("bias"),
1935*da0073e9SAndroid Build Coastguard Worker       py::arg("stride"),
1936*da0073e9SAndroid Build Coastguard Worker       py::arg("padding"),
1937*da0073e9SAndroid Build Coastguard Worker       py::arg("dilation"),
1938*da0073e9SAndroid Build Coastguard Worker       py::arg("transposed"),
1939*da0073e9SAndroid Build Coastguard Worker       py::arg("output_padding"),
1940*da0073e9SAndroid Build Coastguard Worker       py::arg("groups"));
1941*da0073e9SAndroid Build Coastguard Worker 
1942*da0073e9SAndroid Build Coastguard Worker   // overload for bias_sizes_opt/backward TODO: figure out default value
1943*da0073e9SAndroid Build Coastguard Worker   py_module.def(
1944*da0073e9SAndroid Build Coastguard Worker       "_select_conv_backend",
1945*da0073e9SAndroid Build Coastguard Worker       [](const at::Tensor& input,
1946*da0073e9SAndroid Build Coastguard Worker          const at::Tensor& weight,
1947*da0073e9SAndroid Build Coastguard Worker          const std::optional<at::Tensor>& bias,
1948*da0073e9SAndroid Build Coastguard Worker          at::SymIntArrayRef stride_,
1949*da0073e9SAndroid Build Coastguard Worker          at::SymIntArrayRef padding_,
1950*da0073e9SAndroid Build Coastguard Worker          at::SymIntArrayRef dilation_,
1951*da0073e9SAndroid Build Coastguard Worker          bool transposed_,
1952*da0073e9SAndroid Build Coastguard Worker          at::SymIntArrayRef output_padding_,
1953*da0073e9SAndroid Build Coastguard Worker          c10::SymInt groups_,
1954*da0073e9SAndroid Build Coastguard Worker          std::optional<std::vector<c10::SymInt>> bias_sizes_opt) {
1955*da0073e9SAndroid Build Coastguard Worker         c10::OptionalArrayRef<c10::SymInt> ref = std::nullopt;
1956*da0073e9SAndroid Build Coastguard Worker         if (bias_sizes_opt) {
1957*da0073e9SAndroid Build Coastguard Worker           ref = (*bias_sizes_opt);
1958*da0073e9SAndroid Build Coastguard Worker         }
1959*da0073e9SAndroid Build Coastguard Worker         return at::native::select_conv_backend(
1960*da0073e9SAndroid Build Coastguard Worker             input,
1961*da0073e9SAndroid Build Coastguard Worker             weight,
1962*da0073e9SAndroid Build Coastguard Worker             bias,
1963*da0073e9SAndroid Build Coastguard Worker             stride_,
1964*da0073e9SAndroid Build Coastguard Worker             padding_,
1965*da0073e9SAndroid Build Coastguard Worker             dilation_,
1966*da0073e9SAndroid Build Coastguard Worker             transposed_,
1967*da0073e9SAndroid Build Coastguard Worker             output_padding_,
1968*da0073e9SAndroid Build Coastguard Worker             std::move(groups_),
1969*da0073e9SAndroid Build Coastguard Worker             ref);
1970*da0073e9SAndroid Build Coastguard Worker       },
1971*da0073e9SAndroid Build Coastguard Worker       py::arg("input"),
1972*da0073e9SAndroid Build Coastguard Worker       py::arg("weight"),
1973*da0073e9SAndroid Build Coastguard Worker       py::arg("bias"),
1974*da0073e9SAndroid Build Coastguard Worker       py::arg("stride"),
1975*da0073e9SAndroid Build Coastguard Worker       py::arg("padding"),
1976*da0073e9SAndroid Build Coastguard Worker       py::arg("dilation"),
1977*da0073e9SAndroid Build Coastguard Worker       py::arg("transposed"),
1978*da0073e9SAndroid Build Coastguard Worker       py::arg("output_padding"),
1979*da0073e9SAndroid Build Coastguard Worker       py::arg("groups"),
1980*da0073e9SAndroid Build Coastguard Worker       py::arg("bias_sizes"));
1981*da0073e9SAndroid Build Coastguard Worker 
1982*da0073e9SAndroid Build Coastguard Worker   py_module.def(
1983*da0073e9SAndroid Build Coastguard Worker       "_conv_determine_backend_memory_format",
1984*da0073e9SAndroid Build Coastguard Worker       at::native::_determine_backend_memory_format);
1985*da0073e9SAndroid Build Coastguard Worker 
1986*da0073e9SAndroid Build Coastguard Worker   ////////////////////////////////////////////////////////////////////////////////
1987*da0073e9SAndroid Build Coastguard Worker   // Scaled Dot Product Attention utilities
1988*da0073e9SAndroid Build Coastguard Worker   ////////////////////////////////////////////////////////////////////////////////
1989*da0073e9SAndroid Build Coastguard Worker   py::class_<sdp::sdp_params>(py_module, "_SDPAParams")
1990*da0073e9SAndroid Build Coastguard Worker       .def(py::init([](at::Tensor const& query,
1991*da0073e9SAndroid Build Coastguard Worker                        at::Tensor const& key,
1992*da0073e9SAndroid Build Coastguard Worker                        at::Tensor const& value,
1993*da0073e9SAndroid Build Coastguard Worker                        std::optional<at::Tensor> attn_mask,
1994*da0073e9SAndroid Build Coastguard Worker                        double dropout,
1995*da0073e9SAndroid Build Coastguard Worker                        bool is_causal,
1996*da0073e9SAndroid Build Coastguard Worker                        bool enable_gqa) {
1997*da0073e9SAndroid Build Coastguard Worker         return sdp::sdp_params{
1998*da0073e9SAndroid Build Coastguard Worker             query,
1999*da0073e9SAndroid Build Coastguard Worker             key,
2000*da0073e9SAndroid Build Coastguard Worker             value,
2001*da0073e9SAndroid Build Coastguard Worker             std::move(attn_mask),
2002*da0073e9SAndroid Build Coastguard Worker             dropout,
2003*da0073e9SAndroid Build Coastguard Worker             is_causal,
2004*da0073e9SAndroid Build Coastguard Worker             enable_gqa};
2005*da0073e9SAndroid Build Coastguard Worker       }))
2006*da0073e9SAndroid Build Coastguard Worker       .def_readonly("query", &sdp::sdp_params::query)
2007*da0073e9SAndroid Build Coastguard Worker       .def_readonly("key", &sdp::sdp_params::key)
2008*da0073e9SAndroid Build Coastguard Worker       .def_readonly("value", &sdp::sdp_params::value)
2009*da0073e9SAndroid Build Coastguard Worker       .def_readonly("attn_mask", &sdp::sdp_params::attn_mask)
2010*da0073e9SAndroid Build Coastguard Worker       .def_readonly("dropout", &sdp::sdp_params::dropout)
2011*da0073e9SAndroid Build Coastguard Worker       .def_readonly("is_causal", &sdp::sdp_params::is_causal)
2012*da0073e9SAndroid Build Coastguard Worker       .def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa);
2013*da0073e9SAndroid Build Coastguard Worker 
2014*da0073e9SAndroid Build Coastguard Worker   py::enum_<sdp::SDPBackend>(
2015*da0073e9SAndroid Build Coastguard Worker       py_module,
2016*da0073e9SAndroid Build Coastguard Worker       "_SDPBackend",
2017*da0073e9SAndroid Build Coastguard Worker       "An enum-like class that contains the different backends for scaled dot product attention.\n\n... warning:: This class is in beta and subject to change.\n\n"
2018*da0073e9SAndroid Build Coastguard Worker       "This backend class is designed to be used with the sdpa_kernel context manager."
2019*da0073e9SAndroid Build Coastguard Worker       "See :func: torch.nn.attention.sdpa_kernel for more details.")
2020*da0073e9SAndroid Build Coastguard Worker       .value("ERROR", sdp::SDPBackend::error)
2021*da0073e9SAndroid Build Coastguard Worker       .value("MATH", sdp::SDPBackend::math)
2022*da0073e9SAndroid Build Coastguard Worker       .value("FLASH_ATTENTION", sdp::SDPBackend::flash_attention)
2023*da0073e9SAndroid Build Coastguard Worker       .value("EFFICIENT_ATTENTION", sdp::SDPBackend::efficient_attention)
2024*da0073e9SAndroid Build Coastguard Worker       .value("CUDNN_ATTENTION", sdp::SDPBackend::cudnn_attention)
2025*da0073e9SAndroid Build Coastguard Worker       .value("OVERRIDEABLE", sdp::SDPBackend::overrideable);
2026*da0073e9SAndroid Build Coastguard Worker 
2027*da0073e9SAndroid Build Coastguard Worker   py_module.def("_is_flash_attention_available", []() {
2028*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
2029*da0073e9SAndroid Build Coastguard Worker     return sdp::is_flash_attention_available();
2030*da0073e9SAndroid Build Coastguard Worker #else
2031*da0073e9SAndroid Build Coastguard Worker     return false;
2032*da0073e9SAndroid Build Coastguard Worker #endif
2033*da0073e9SAndroid Build Coastguard Worker   });
2034*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2035*da0073e9SAndroid Build Coastguard Worker       "_can_use_flash_attention",
2036*da0073e9SAndroid Build Coastguard Worker       [](const sdp::sdp_params& params, bool debug) {
2037*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
2038*da0073e9SAndroid Build Coastguard Worker         return sdp::can_use_flash_attention(params, debug);
2039*da0073e9SAndroid Build Coastguard Worker #else
2040*da0073e9SAndroid Build Coastguard Worker         return false;
2041*da0073e9SAndroid Build Coastguard Worker #endif
2042*da0073e9SAndroid Build Coastguard Worker       });
2043*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2044*da0073e9SAndroid Build Coastguard Worker       "_can_use_mem_efficient_attention",
2045*da0073e9SAndroid Build Coastguard Worker       [](const sdp::sdp_params& params, bool debug) {
2046*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
2047*da0073e9SAndroid Build Coastguard Worker         return sdp::can_use_mem_efficient_attention(params, debug);
2048*da0073e9SAndroid Build Coastguard Worker #else
2049*da0073e9SAndroid Build Coastguard Worker         return false;
2050*da0073e9SAndroid Build Coastguard Worker #endif
2051*da0073e9SAndroid Build Coastguard Worker       });
2052*da0073e9SAndroid Build Coastguard Worker 
2053*da0073e9SAndroid Build Coastguard Worker   py::enum_<at::LinalgBackend>(py_module, "_LinalgBackend")
2054*da0073e9SAndroid Build Coastguard Worker       .value("Default", at::LinalgBackend::Default)
2055*da0073e9SAndroid Build Coastguard Worker       .value("Cusolver", at::LinalgBackend::Cusolver)
2056*da0073e9SAndroid Build Coastguard Worker       .value("Magma", at::LinalgBackend::Magma);
2057*da0073e9SAndroid Build Coastguard Worker 
2058*da0073e9SAndroid Build Coastguard Worker   py_module.def("_set_linalg_preferred_backend", [](at::LinalgBackend b) {
2059*da0073e9SAndroid Build Coastguard Worker     at::globalContext().setLinalgPreferredBackend(b);
2060*da0073e9SAndroid Build Coastguard Worker   });
2061*da0073e9SAndroid Build Coastguard Worker   py_module.def("_get_linalg_preferred_backend", []() {
2062*da0073e9SAndroid Build Coastguard Worker     return at::globalContext().linalgPreferredBackend();
2063*da0073e9SAndroid Build Coastguard Worker   });
2064*da0073e9SAndroid Build Coastguard Worker 
2065*da0073e9SAndroid Build Coastguard Worker   py::enum_<at::BlasBackend>(py_module, "_BlasBackend")
2066*da0073e9SAndroid Build Coastguard Worker       .value("Cublas", at::BlasBackend::Cublas)
2067*da0073e9SAndroid Build Coastguard Worker       .value("Cublaslt", at::BlasBackend::Cublaslt);
2068*da0073e9SAndroid Build Coastguard Worker 
2069*da0073e9SAndroid Build Coastguard Worker   py_module.def("_set_blas_preferred_backend", [](at::BlasBackend b) {
2070*da0073e9SAndroid Build Coastguard Worker     at::globalContext().setBlasPreferredBackend(b);
2071*da0073e9SAndroid Build Coastguard Worker   });
2072*da0073e9SAndroid Build Coastguard Worker   py_module.def("_get_blas_preferred_backend", []() {
2073*da0073e9SAndroid Build Coastguard Worker     return at::globalContext().blasPreferredBackend();
2074*da0073e9SAndroid Build Coastguard Worker   });
2075*da0073e9SAndroid Build Coastguard Worker 
2076*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2077*da0073e9SAndroid Build Coastguard Worker       "_construct_storage_from_data_pointer",
2078*da0073e9SAndroid Build Coastguard Worker       [](int64_t data_ptr, c10::Device device, size_t size_bytes) {
2079*da0073e9SAndroid Build Coastguard Worker         return c10::Storage(
2080*da0073e9SAndroid Build Coastguard Worker             c10::Storage::use_byte_size_t(),
2081*da0073e9SAndroid Build Coastguard Worker             size_bytes,
2082*da0073e9SAndroid Build Coastguard Worker             // NOLINTNEXTLINE(performance-no-int-to-ptr)
2083*da0073e9SAndroid Build Coastguard Worker             at::DataPtr(reinterpret_cast<void*>(data_ptr), device));
2084*da0073e9SAndroid Build Coastguard Worker       });
2085*da0073e9SAndroid Build Coastguard Worker 
2086*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2087*da0073e9SAndroid Build Coastguard Worker       "_stash_obj_in_tls", [](const std::string& key, py::handle arg) {
2088*da0073e9SAndroid Build Coastguard Worker         at::impl::ThreadLocalPythonObjects::get_state().set(
2089*da0073e9SAndroid Build Coastguard Worker             key,
2090*da0073e9SAndroid Build Coastguard Worker             std::make_shared<c10::SafePyObject>(arg.ptr(), getPyInterpreter()));
2091*da0073e9SAndroid Build Coastguard Worker       });
2092*da0073e9SAndroid Build Coastguard Worker 
2093*da0073e9SAndroid Build Coastguard Worker   py_module.def("_get_obj_in_tls", [](const std::string& key) -> py::handle {
2094*da0073e9SAndroid Build Coastguard Worker     auto safe_pyobject =
2095*da0073e9SAndroid Build Coastguard Worker         at::impl::ThreadLocalPythonObjects::get_state().get(key);
2096*da0073e9SAndroid Build Coastguard Worker     auto obj = safe_pyobject->ptr(getPyInterpreter());
2097*da0073e9SAndroid Build Coastguard Worker     return py::handle(obj);
2098*da0073e9SAndroid Build Coastguard Worker   });
2099*da0073e9SAndroid Build Coastguard Worker 
2100*da0073e9SAndroid Build Coastguard Worker   py_module.def("_is_key_in_tls", [](const std::string& key) -> bool {
2101*da0073e9SAndroid Build Coastguard Worker     return at::impl::ThreadLocalPythonObjects::get_state().contains(key);
2102*da0073e9SAndroid Build Coastguard Worker   });
2103*da0073e9SAndroid Build Coastguard Worker 
2104*da0073e9SAndroid Build Coastguard Worker   py_module.def("_accelerator_hooks_device_count", []() {
2105*da0073e9SAndroid Build Coastguard Worker     auto device_type = at::getAccelerator();
2106*da0073e9SAndroid Build Coastguard Worker     if (device_type.has_value()) {
2107*da0073e9SAndroid Build Coastguard Worker       return at::globalContext()
2108*da0073e9SAndroid Build Coastguard Worker           .getAcceleratorHooksInterface(device_type.value())
2109*da0073e9SAndroid Build Coastguard Worker           .deviceCount();
2110*da0073e9SAndroid Build Coastguard Worker     }
2111*da0073e9SAndroid Build Coastguard Worker     return c10::DeviceIndex(-1);
2112*da0073e9SAndroid Build Coastguard Worker   });
2113*da0073e9SAndroid Build Coastguard Worker 
2114*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2115*da0073e9SAndroid Build Coastguard Worker       "_accelerator_hooks_set_current_device",
2116*da0073e9SAndroid Build Coastguard Worker       [](c10::DeviceIndex device_index) {
2117*da0073e9SAndroid Build Coastguard Worker         auto device_type = at::getAccelerator();
2118*da0073e9SAndroid Build Coastguard Worker         if (device_type.has_value()) {
2119*da0073e9SAndroid Build Coastguard Worker           at::globalContext()
2120*da0073e9SAndroid Build Coastguard Worker               .getAcceleratorHooksInterface(device_type.value())
2121*da0073e9SAndroid Build Coastguard Worker               .setCurrentDevice(device_index);
2122*da0073e9SAndroid Build Coastguard Worker         }
2123*da0073e9SAndroid Build Coastguard Worker       });
2124*da0073e9SAndroid Build Coastguard Worker 
2125*da0073e9SAndroid Build Coastguard Worker   py_module.def("_accelerator_hooks_get_current_device", []() {
2126*da0073e9SAndroid Build Coastguard Worker     auto device_type = at::getAccelerator();
2127*da0073e9SAndroid Build Coastguard Worker     if (device_type.has_value()) {
2128*da0073e9SAndroid Build Coastguard Worker       return at::globalContext()
2129*da0073e9SAndroid Build Coastguard Worker           .getAcceleratorHooksInterface(device_type.value())
2130*da0073e9SAndroid Build Coastguard Worker           .getCurrentDevice();
2131*da0073e9SAndroid Build Coastguard Worker     }
2132*da0073e9SAndroid Build Coastguard Worker     return c10::DeviceIndex(-1);
2133*da0073e9SAndroid Build Coastguard Worker   });
2134*da0073e9SAndroid Build Coastguard Worker 
2135*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2136*da0073e9SAndroid Build Coastguard Worker       "_accelerator_hooks_exchange_device", [](c10::DeviceIndex device_index) {
2137*da0073e9SAndroid Build Coastguard Worker         auto device_type = at::getAccelerator();
2138*da0073e9SAndroid Build Coastguard Worker         if (device_type.has_value()) {
2139*da0073e9SAndroid Build Coastguard Worker           return at::globalContext()
2140*da0073e9SAndroid Build Coastguard Worker               .getAcceleratorHooksInterface(device_type.value())
2141*da0073e9SAndroid Build Coastguard Worker               .exchangeDevice(device_index);
2142*da0073e9SAndroid Build Coastguard Worker         }
2143*da0073e9SAndroid Build Coastguard Worker         return c10::DeviceIndex(-1);
2144*da0073e9SAndroid Build Coastguard Worker       });
2145*da0073e9SAndroid Build Coastguard Worker 
2146*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2147*da0073e9SAndroid Build Coastguard Worker       "_accelerator_hooks_maybe_exchange_device",
2148*da0073e9SAndroid Build Coastguard Worker       [](c10::DeviceIndex device_index) {
2149*da0073e9SAndroid Build Coastguard Worker         auto device_type = at::getAccelerator();
2150*da0073e9SAndroid Build Coastguard Worker         if (device_type.has_value()) {
2151*da0073e9SAndroid Build Coastguard Worker           return at::globalContext()
2152*da0073e9SAndroid Build Coastguard Worker               .getAcceleratorHooksInterface(device_type.value())
2153*da0073e9SAndroid Build Coastguard Worker               .maybeExchangeDevice(device_index);
2154*da0073e9SAndroid Build Coastguard Worker         }
2155*da0073e9SAndroid Build Coastguard Worker         return c10::DeviceIndex(-1);
2156*da0073e9SAndroid Build Coastguard Worker       });
2157*da0073e9SAndroid Build Coastguard Worker 
2158*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2159*da0073e9SAndroid Build Coastguard Worker       "_get_accelerator",
2160*da0073e9SAndroid Build Coastguard Worker       [](std::optional<bool> check = std::nullopt) {
2161*da0073e9SAndroid Build Coastguard Worker         return c10::Device(
2162*da0073e9SAndroid Build Coastguard Worker             at::getAccelerator(check.value_or(false))
2163*da0073e9SAndroid Build Coastguard Worker                 .value_or(c10::DeviceType::CPU),
2164*da0073e9SAndroid Build Coastguard Worker             -1);
2165*da0073e9SAndroid Build Coastguard Worker       },
2166*da0073e9SAndroid Build Coastguard Worker       py::arg("check") = nullptr);
2167*da0073e9SAndroid Build Coastguard Worker 
2168*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
2169*da0073e9SAndroid Build Coastguard Worker   PyObject* has_cuda = Py_True;
2170*da0073e9SAndroid Build Coastguard Worker #else
2171*da0073e9SAndroid Build Coastguard Worker   PyObject* has_cuda = Py_False;
2172*da0073e9SAndroid Build Coastguard Worker #endif
2173*da0073e9SAndroid Build Coastguard Worker 
2174*da0073e9SAndroid Build Coastguard Worker #ifdef USE_MPS
2175*da0073e9SAndroid Build Coastguard Worker   PyObject* has_mps = Py_True;
2176*da0073e9SAndroid Build Coastguard Worker #else
2177*da0073e9SAndroid Build Coastguard Worker   PyObject* has_mps = Py_False;
2178*da0073e9SAndroid Build Coastguard Worker #endif
2179*da0073e9SAndroid Build Coastguard Worker 
2180*da0073e9SAndroid Build Coastguard Worker #ifdef USE_XPU
2181*da0073e9SAndroid Build Coastguard Worker   PyObject* has_xpu = Py_True;
2182*da0073e9SAndroid Build Coastguard Worker #else
2183*da0073e9SAndroid Build Coastguard Worker   PyObject* has_xpu = Py_False;
2184*da0073e9SAndroid Build Coastguard Worker #endif
2185*da0073e9SAndroid Build Coastguard Worker 
2186*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_has_cuda", has_cuda));
2187*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
2188*da0073e9SAndroid Build Coastguard Worker       set_module_attr("_has_magma", at::hasMAGMA() ? Py_True : Py_False));
2189*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_has_mps", has_mps));
2190*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_has_xpu", has_xpu));
2191*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
2192*da0073e9SAndroid Build Coastguard Worker       set_module_attr("_has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
2193*da0073e9SAndroid Build Coastguard Worker 
2194*da0073e9SAndroid Build Coastguard Worker #ifdef _GLIBCXX_USE_CXX11_ABI
2195*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr(
2196*da0073e9SAndroid Build Coastguard Worker       "_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False));
2197*da0073e9SAndroid Build Coastguard Worker #else
2198*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False));
2199*da0073e9SAndroid Build Coastguard Worker #endif
2200*da0073e9SAndroid Build Coastguard Worker 
2201*da0073e9SAndroid Build Coastguard Worker // See note [Pybind11 ABI constants]
2202*da0073e9SAndroid Build Coastguard Worker #define SET_STR_DEFINE(name) \
2203*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_" #name, THPUtils_packString(name)))
2204*da0073e9SAndroid Build Coastguard Worker 
2205*da0073e9SAndroid Build Coastguard Worker #ifdef PYBIND11_COMPILER_TYPE
2206*da0073e9SAndroid Build Coastguard Worker   SET_STR_DEFINE(PYBIND11_COMPILER_TYPE);
2207*da0073e9SAndroid Build Coastguard Worker #else
2208*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
2209*da0073e9SAndroid Build Coastguard Worker       set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None));
2210*da0073e9SAndroid Build Coastguard Worker #endif
2211*da0073e9SAndroid Build Coastguard Worker 
2212*da0073e9SAndroid Build Coastguard Worker #ifdef PYBIND11_STDLIB
2213*da0073e9SAndroid Build Coastguard Worker   SET_STR_DEFINE(PYBIND11_STDLIB);
2214*da0073e9SAndroid Build Coastguard Worker #else
2215*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None));
2216*da0073e9SAndroid Build Coastguard Worker #endif
2217*da0073e9SAndroid Build Coastguard Worker 
2218*da0073e9SAndroid Build Coastguard Worker #ifdef PYBIND11_BUILD_ABI
2219*da0073e9SAndroid Build Coastguard Worker   SET_STR_DEFINE(PYBIND11_BUILD_ABI);
2220*da0073e9SAndroid Build Coastguard Worker #else
2221*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None));
2222*da0073e9SAndroid Build Coastguard Worker #endif
2223*da0073e9SAndroid Build Coastguard Worker #undef SET_STR_DEFINE
2224*da0073e9SAndroid Build Coastguard Worker 
2225*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2226*da0073e9SAndroid Build Coastguard Worker       "_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); });
2227*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2228*da0073e9SAndroid Build Coastguard Worker       "_set_neg", [](const at::Tensor& x, bool neg) { x._set_neg(neg); });
2229*da0073e9SAndroid Build Coastguard Worker   py_module.def("_get_tensor_metadata", &torch::jit::getTensorMetadata);
2230*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2231*da0073e9SAndroid Build Coastguard Worker       "_set_tensor_metadata",
2232*da0073e9SAndroid Build Coastguard Worker       static_cast<void (*)(
2233*da0073e9SAndroid Build Coastguard Worker           const at::Tensor&, std::unordered_map<std::string, bool>)>(
2234*da0073e9SAndroid Build Coastguard Worker           torch::jit::setTensorMetadata));
2235*da0073e9SAndroid Build Coastguard Worker   py_module.def("_dispatch_key_set", [](const at::Tensor& x) {
2236*da0073e9SAndroid Build Coastguard Worker     return toString(x.key_set());
2237*da0073e9SAndroid Build Coastguard Worker   });
2238*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2239*da0073e9SAndroid Build Coastguard Worker       "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });
2240*da0073e9SAndroid Build Coastguard Worker 
2241*da0073e9SAndroid Build Coastguard Worker   py_module.def("_set_meta_in_tls_dispatch_include", [](bool meta_in_tls) {
2242*da0073e9SAndroid Build Coastguard Worker     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
2243*da0073e9SAndroid Build Coastguard Worker     c10::DispatchKeySet key_set({at::DispatchKey::Meta});
2244*da0073e9SAndroid Build Coastguard Worker     if (meta_in_tls) {
2245*da0073e9SAndroid Build Coastguard Worker       local_keyset.included_ = local_keyset.included_ | key_set;
2246*da0073e9SAndroid Build Coastguard Worker     } else {
2247*da0073e9SAndroid Build Coastguard Worker       local_keyset.included_ =
2248*da0073e9SAndroid Build Coastguard Worker           local_keyset.included_.remove_backend(c10::BackendComponent::MetaBit);
2249*da0073e9SAndroid Build Coastguard Worker     }
2250*da0073e9SAndroid Build Coastguard Worker     c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
2251*da0073e9SAndroid Build Coastguard Worker   });
2252*da0073e9SAndroid Build Coastguard Worker 
2253*da0073e9SAndroid Build Coastguard Worker   py_module.def("_meta_in_tls_dispatch_include", []() {
2254*da0073e9SAndroid Build Coastguard Worker     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
2255*da0073e9SAndroid Build Coastguard Worker     return local_keyset.included_.has_backend(c10::BackendComponent::MetaBit);
2256*da0073e9SAndroid Build Coastguard Worker   });
2257*da0073e9SAndroid Build Coastguard Worker 
2258*da0073e9SAndroid Build Coastguard Worker   py_module.def("_dump_local_tls_set", []() {
2259*da0073e9SAndroid Build Coastguard Worker     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
2260*da0073e9SAndroid Build Coastguard Worker     std::cout << "Included: " << toString(local_keyset.included_) << "\n";
2261*da0073e9SAndroid Build Coastguard Worker     std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
2262*da0073e9SAndroid Build Coastguard Worker   });
2263*da0073e9SAndroid Build Coastguard Worker 
2264*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2265*da0073e9SAndroid Build Coastguard Worker       "_should_allow_numbers_as_tensors", [](const std::string& name) {
2266*da0073e9SAndroid Build Coastguard Worker         return torch::should_allow_numbers_as_tensors(name);
2267*da0073e9SAndroid Build Coastguard Worker       });
2268*da0073e9SAndroid Build Coastguard Worker 
2269*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2270*da0073e9SAndroid Build Coastguard Worker       "_group_tensors_by_device_and_dtype",
2271*da0073e9SAndroid Build Coastguard Worker       [](const std::vector<std::vector<std::optional<at::Tensor>>>&
2272*da0073e9SAndroid Build Coastguard Worker              nested_tensorlist,
2273*da0073e9SAndroid Build Coastguard Worker          const bool with_indices) {
2274*da0073e9SAndroid Build Coastguard Worker         return at::native::_group_tensors_by_first_tensors_device_and_dtype(
2275*da0073e9SAndroid Build Coastguard Worker             nested_tensorlist, with_indices);
2276*da0073e9SAndroid Build Coastguard Worker       });
2277*da0073e9SAndroid Build Coastguard Worker 
2278*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2279*da0073e9SAndroid Build Coastguard Worker       "_storage_address",
2280*da0073e9SAndroid Build Coastguard Worker       [](const at::Tensor& tensor) {
2281*da0073e9SAndroid Build Coastguard Worker         return reinterpret_cast<std::intptr_t>(
2282*da0073e9SAndroid Build Coastguard Worker             tensor.storage().unsafeGetStorageImpl());
2283*da0073e9SAndroid Build Coastguard Worker       },
2284*da0073e9SAndroid Build Coastguard Worker       "Gets the memory address of the Tensor's StorageImpl.");
2285*da0073e9SAndroid Build Coastguard Worker 
2286*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2287*da0073e9SAndroid Build Coastguard Worker       "_data_address",
2288*da0073e9SAndroid Build Coastguard Worker       [](const at::Tensor& tensor) {
2289*da0073e9SAndroid Build Coastguard Worker         return reinterpret_cast<std::intptr_t>(tensor.storage().data());
2290*da0073e9SAndroid Build Coastguard Worker       },
2291*da0073e9SAndroid Build Coastguard Worker       "Gets the memory address of the Tensor's data pointer.");
2292*da0073e9SAndroid Build Coastguard Worker 
2293*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2294*da0073e9SAndroid Build Coastguard Worker       "_is_cow_tensor",
2295*da0073e9SAndroid Build Coastguard Worker       [](const at::Tensor& tensor) {
2296*da0073e9SAndroid Build Coastguard Worker         return c10::impl::cow::is_cow_data_ptr(tensor.storage().data_ptr());
2297*da0073e9SAndroid Build Coastguard Worker       },
2298*da0073e9SAndroid Build Coastguard Worker       "Checks if a tensor's data pointer is COW");
2299*da0073e9SAndroid Build Coastguard Worker 
2300*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2301*da0073e9SAndroid Build Coastguard Worker       "_get_cudnn_batch_norm_reserve_space_size",
2302*da0073e9SAndroid Build Coastguard Worker       [](const at::Tensor& input, bool training) {
2303*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA
2304*da0073e9SAndroid Build Coastguard Worker         return at::native::_get_cudnn_batch_norm_reserve_space_size(
2305*da0073e9SAndroid Build Coastguard Worker             input, training);
2306*da0073e9SAndroid Build Coastguard Worker #else
2307*da0073e9SAndroid Build Coastguard Worker         TORCH_CHECK(false, "PyTorch was not built with cuda");
2308*da0073e9SAndroid Build Coastguard Worker #endif
2309*da0073e9SAndroid Build Coastguard Worker       },
2310*da0073e9SAndroid Build Coastguard Worker       py::arg("input"),
2311*da0073e9SAndroid Build Coastguard Worker       py::arg("training"));
2312*da0073e9SAndroid Build Coastguard Worker 
2313*da0073e9SAndroid Build Coastguard Worker   py::enum_<at::native::BatchNormBackend>(py_module, "_BatchNormBackend")
2314*da0073e9SAndroid Build Coastguard Worker       .value("Native", at::native::BatchNormBackend::Native)
2315*da0073e9SAndroid Build Coastguard Worker       .value("Cudnn", at::native::BatchNormBackend::Cudnn)
2316*da0073e9SAndroid Build Coastguard Worker       .value("Miopen", at::native::BatchNormBackend::Miopen);
2317*da0073e9SAndroid Build Coastguard Worker 
2318*da0073e9SAndroid Build Coastguard Worker   py_module.def(
2319*da0073e9SAndroid Build Coastguard Worker       "_select_batch_norm_backend",
2320*da0073e9SAndroid Build Coastguard Worker       [](const at::Tensor& input,
2321*da0073e9SAndroid Build Coastguard Worker          const at::Tensor& weight,
2322*da0073e9SAndroid Build Coastguard Worker          const at::Tensor& bias,
2323*da0073e9SAndroid Build Coastguard Worker          const at::Tensor& running_mean,
2324*da0073e9SAndroid Build Coastguard Worker          const at::Tensor& running_var,
2325*da0073e9SAndroid Build Coastguard Worker          bool training,
2326*da0073e9SAndroid Build Coastguard Worker          double eps) {
2327*da0073e9SAndroid Build Coastguard Worker         return at::native::_select_batch_norm_backend(
2328*da0073e9SAndroid Build Coastguard Worker             input, weight, bias, running_mean, running_var, training, eps);
2329*da0073e9SAndroid Build Coastguard Worker       },
2330*da0073e9SAndroid Build Coastguard Worker       py::arg("input"),
2331*da0073e9SAndroid Build Coastguard Worker       py::arg("weight"),
2332*da0073e9SAndroid Build Coastguard Worker       py::arg("bias"),
2333*da0073e9SAndroid Build Coastguard Worker       py::arg("running_mean"),
2334*da0073e9SAndroid Build Coastguard Worker       py::arg("running_var"),
2335*da0073e9SAndroid Build Coastguard Worker       py::arg("training"),
2336*da0073e9SAndroid Build Coastguard Worker       py::arg("eps"));
2337*da0073e9SAndroid Build Coastguard Worker 
2338*da0073e9SAndroid Build Coastguard Worker   const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
2339*da0073e9SAndroid Build Coastguard Worker   THPDefaultCPUGenerator =
2340*da0073e9SAndroid Build Coastguard Worker       (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
2341*da0073e9SAndroid Build Coastguard Worker   // This reference is meant to be given away, so no need to incref here.
2342*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr(
2343*da0073e9SAndroid Build Coastguard Worker       "default_generator",
2344*da0073e9SAndroid Build Coastguard Worker       (PyObject*)THPDefaultCPUGenerator,
2345*da0073e9SAndroid Build Coastguard Worker       /* incref= */ false));
2346*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr(
2347*da0073e9SAndroid Build Coastguard Worker       "DisableTorchFunctionSubclass",
2348*da0073e9SAndroid Build Coastguard Worker       (PyObject*)THPModule_DisableTorchFunctionSubclassType(),
2349*da0073e9SAndroid Build Coastguard Worker       /* incref= */ false));
2350*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(set_module_attr(
2351*da0073e9SAndroid Build Coastguard Worker       "DisableTorchFunction",
2352*da0073e9SAndroid Build Coastguard Worker       (PyObject*)THPModule_DisableTorchFunctionType(),
2353*da0073e9SAndroid Build Coastguard Worker       /* incref= */ false));
2354*da0073e9SAndroid Build Coastguard Worker   torch::set_disabled_torch_function_impl(
2355*da0073e9SAndroid Build Coastguard Worker       PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
2356*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr);
2357*da0073e9SAndroid Build Coastguard Worker   torch::set_disabled_torch_dispatch_impl(
2358*da0073e9SAndroid Build Coastguard Worker       PyObject_GetAttrString(module, "_disabled_torch_dispatch_impl"));
2359*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::disabled_torch_dispatch_impl() != nullptr);
2360*da0073e9SAndroid Build Coastguard Worker   return module;
2361*da0073e9SAndroid Build Coastguard Worker   END_HANDLE_TH_ERRORS
2362*da0073e9SAndroid Build Coastguard Worker }
2363*da0073e9SAndroid Build Coastguard Worker 
2364*da0073e9SAndroid Build Coastguard Worker // Checks that the _C shared library isn't initialized multiple times. This
2365*da0073e9SAndroid Build Coastguard Worker // can happen if the same csrc files are compiled into multiple shared
2366*da0073e9SAndroid Build Coastguard Worker // libraries.
pytorch_duplicate_guard()2367*da0073e9SAndroid Build Coastguard Worker inline void pytorch_duplicate_guard() {
2368*da0073e9SAndroid Build Coastguard Worker   static int initialized = 0;
2369*da0073e9SAndroid Build Coastguard Worker   if (initialized) {
2370*da0073e9SAndroid Build Coastguard Worker     fmt::print(stderr, "pytorch: _C shared library re-initialized\n");
2371*da0073e9SAndroid Build Coastguard Worker     abort();
2372*da0073e9SAndroid Build Coastguard Worker   }
2373*da0073e9SAndroid Build Coastguard Worker   initialized = 1;
2374*da0073e9SAndroid Build Coastguard Worker   ;
2375*da0073e9SAndroid Build Coastguard Worker }
2376*da0073e9SAndroid Build Coastguard Worker 
2377*da0073e9SAndroid Build Coastguard Worker struct call_duplicate_guard {
call_duplicate_guardcall_duplicate_guard2378*da0073e9SAndroid Build Coastguard Worker   call_duplicate_guard() {
2379*da0073e9SAndroid Build Coastguard Worker     pytorch_duplicate_guard();
2380*da0073e9SAndroid Build Coastguard Worker   }
2381*da0073e9SAndroid Build Coastguard Worker };
2382*da0073e9SAndroid Build Coastguard Worker 
2383*da0073e9SAndroid Build Coastguard Worker static call_duplicate_guard _call_duplicate_guard;
2384