xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/PythonOpRegistrationTrampoline.h>
2 
3 namespace at::impl {
4 
5 // The strategy is that all python interpreters attempt to register themselves
6 // as the main interpreter, but only one wins.  Only that interpreter is
7 // allowed to interact with the C++ dispatcher.  Furthermore, when we execute
8 // logic on that interpreter, we do so hermetically, never setting pyobj field
9 // on Tensor.
10 
11 std::atomic<c10::impl::PyInterpreter*>
12     PythonOpRegistrationTrampoline::interpreter_{nullptr};
13 
getInterpreter()14 c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::getInterpreter() {
15   return PythonOpRegistrationTrampoline::interpreter_.load();
16 }
17 
registerInterpreter(c10::impl::PyInterpreter * interp)18 bool PythonOpRegistrationTrampoline::registerInterpreter(
19     c10::impl::PyInterpreter* interp) {
20   c10::impl::PyInterpreter* expected = nullptr;
21   interpreter_.compare_exchange_strong(expected, interp);
22   if (expected != nullptr) {
23     // This is the second (or later) Python interpreter, which means we need
24     // non-trivial hermetic PyObject TLS
25     c10::impl::HermeticPyObjectTLS::init_state();
26     return false;
27   } else {
28     return true;
29   }
30 }
31 
32 } // namespace at::impl
33