xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_engine.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/autograd/python_engine.h>

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapMode.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>

#ifndef _WIN32
#include <pthread.h>
#endif

#include <memory> // for unique_ptr
#include <utility>

using namespace torch::autograd;

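// THPEngine is a stateless Python binding object. The type is registered on
// the module as _ImperativeEngine (tp_name "torch._C._EngineBase") in
// THPEngine_initModule below, and all of its methods dispatch to the
// process-wide PythonEngine singleton returned by
// PythonEngine::get_python_engine().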
struct THPEngine {
  PyObject_HEAD
};

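// Set to true by the pthread_atfork() child handler (child_atfork() below, see
// THPEngine_initModule), so that the forked child tears down and rebuilds the
// autograd engine lazily on its next call to get_python_engine().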
static bool _reinitialize_engine = false;

namespace torch::autograd::python {

PythonEngine::PythonEngine() = default;

Engine& PythonEngine::get_python_engine() {
  static PythonEngine engine;
  // This is "probably" thread-safe because the flag is set in a fork handler
  // before any threads are created, and this function is only called with the
  // GIL held. However, using fork + threads is playing with fire so this is
  // more of a "best effort" thing. For example, if the fork occurs while the
  // backwards threads hold a lock, we'll probably deadlock in the engine
  // destructor.
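  // Destroying the engine in place and re-constructing it with placement new
  // keeps the address of the static `engine` stable, so references obtained
  // from earlier calls to get_python_engine() remain usable in the forked
  // child.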
  if (_reinitialize_engine) {
    engine.release_workers();
    engine.~PythonEngine();
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
  return engine;
}

PythonEngine::~PythonEngine() {
  Engine::stop();
}

#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9
#define IS_PYTHON_3_9_PLUS
#endif

void PythonEngine::thread_init(
    int device,
    const std::shared_ptr<ReadyQueue>& ready_queue,
    bool should_increment) {
  // Increment thread usage count before acquiring the GIL
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }
  // Create a PyThreadState, but release the GIL. This lets
  // pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL
  // without having to create a new PyThreadState each time.
#if defined(IS_PYTHON_3_9_PLUS)
  auto gil = std::make_unique<pybind11::gil_scoped_acquire>();
#else
  pybind11::gil_scoped_acquire gil;
#endif
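  // On Python 3.9+ the acquire guard is heap-allocated so that, if the
  // interpreter is already finalizing when this thread shuts down, its
  // destructor can be skipped: the guard is release()d and its storage freed
  // directly (see the Py_IsInitialized() check at the end of this function).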
  pybind11::gil_scoped_release no_gil;
  Engine::thread_init(device, ready_queue, false);

  if (should_increment) {
    // Decrement the count during shutdown if we incremented earlier.
    decrement_non_reentrant_thread_count();
  }

#if defined(IS_PYTHON_3_9_PLUS)
  // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if
  // runtime is finalizing
  if (!Py_IsInitialized()) {
    no_gil.disarm();
    // TODO: call disarm once PyThreadState_Clear can safely be called from
    // finalize.
    // NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct
    // PyThreadState, so avoid use-after-free here.
    auto ptr = gil.release();
    operator delete(ptr);
  }
#endif
}

void PythonEngine::thread_on_exception(
    const std::shared_ptr<GraphTask>& graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  // See Note [ Persisting PyErr state across autograd engine threads ]
  auto python_err = dynamic_cast<python_error*>(&e);
  if (python_err) {
    python_err->persist();
  }
  Engine::thread_on_exception(graph_task, fn, e);
}

std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  return std::make_unique<PyAnomalyMetadata>();
}

std::unique_ptr<SavedVariableHooks> PythonEngine::
    get_default_saved_variable_hooks() {
  return PyDefaultSavedVariableHooks::get_hooks();
}

variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    bool accumulate_grad,
    const edge_list& outputs) {
  TORCH_CHECK(
      !PyGILState_Check(),
      "The autograd engine was called while holding the GIL. If you are using the C++ "
      "API, the autograd engine is an expensive operation that does not require the "
      "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
      ". If you are not using the C++ API, please report a bug to the pytorch team.")
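  // For reference, a minimal sketch (names illustrative) of the calling
  // pattern this check expects from C++ code that currently holds the GIL:
  //
  //   pybind11::gil_scoped_release no_gil;
  //   auto grads = engine.execute(
  //       roots, inputs, /*keep_graph=*/false, /*create_graph=*/false,
  //       /*accumulate_grad=*/false, outputs);
  //
  // THPEngine_run_backward below releases the GIL in exactly this way before
  // calling into the engine.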
  try {
    return Engine::execute(
        roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}

c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  try {
    return Engine::execute_with_graph_task(
        graph_task, std::move(graph_root), std::move(input_buffer));
  } catch (python_error& e) {
    pybind11::gil_scoped_acquire gil;
    if (!PyErr_Occurred()) {
      // Set the error indicator only if it is not set already.
      e.restore();
    }
    throw;
  }
}
} // namespace torch::autograd::python

PyObject* THPEngineClass = nullptr;

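// Parses a GradientEdge-like object: a tuple whose first item is an
// autograd.graph.Node (either a Python THPFunction or a C++-backed
// THPCppFunction) and whose second item is the integer output_nr. On the
// Python side this corresponds to torch.autograd.graph.GradientEdge.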
inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
  PyObject* grad_fn = PyTuple_GetItem(obj, 0);
  auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
  std::shared_ptr<torch::autograd::Node> grad_fn_sp;
  if (THPFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
  } else if (THPCppFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
  } else {
    TORCH_CHECK(
        false,
        "GradientEdge's first object must be an autograd.graph.Node "
        "but got ",
        THPUtils_typename(grad_fn));
  }
  return Edge(grad_fn_sp, output_nr);
}

// Implementation of torch._C._EngineBase.run_backward
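// Both torch.autograd.backward() and torch.autograd.grad() reach this entry
// point via Variable._execution_engine.run_backward(...); backward() passes
// accumulate_grad=True, while grad() passes accumulate_grad=False along with
// its `inputs`.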
PyObject* THPEngine_run_backward(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  PyObject* tensors = nullptr;
  PyObject* grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject* inputs = nullptr;
  unsigned char allow_unreachable = 0;
  unsigned char accumulate_grad =
      0; // Whether to accumulate grads into leaf Tensors (backward()) or to
         // capture and return them (grad())
  constexpr const char* accepted_kwargs[] = {// NOLINT
                                             "tensors",
                                             "grad_tensors",
                                             "keep_graph",
                                             "create_graph",
                                             "inputs",
                                             "allow_unreachable",
                                             "accumulate_grad",
                                             nullptr};
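  // Format "OObb|Obb": tensors (O) and grad_tensors (O) are required objects,
  // keep_graph (b) and create_graph (b) are required unsigned chars; the
  // remaining inputs (O), allow_unreachable (b) and accumulate_grad (b) are
  // optional.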
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "OObb|Obb",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,-warnings-as-errors)
          const_cast<char**>(accepted_kwargs),
          &tensors,
          &grad_tensors,
          &keep_graph,
          &create_graph,
          &inputs,
          &allow_unreachable,
          &accumulate_grad))
    return nullptr;
  TORCH_CHECK(
      PyTuple_Check(tensors),
      "tensors argument is expected to "
      "be a tuple, but got ",
      THPUtils_typename(tensors));
  TORCH_CHECK(
      PyTuple_Check(grad_tensors),
      "grad_tensors argument is "
      "expected to be a tuple, but got ",
      THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  TORCH_CHECK(
      num_tensors == num_gradients,
      "got ",
      num_tensors,
      " tensors and ",
      num_gradients,
      " gradients");

  // The user either called autograd.backward(...) or autograd.grad(...) to get
  // here
  bool backward_api_called = accumulate_grad;
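  // backward_api_called is true only for the backward() path; the grad() path
  // passes accumulate_grad=False together with `inputs`, and its results are
  // returned through py_outputs at the end of this function instead of being
  // accumulated into the leaves' .grad fields.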
  TORCH_CHECK(
      !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
    Edge gradient_edge; // Temporary variable to hold the gradient edge
    std::optional<at::Tensor> mb_output;
    if (THPVariable_Check(_tensor)) {
      mb_output = THPVariable_Unpack(_tensor);
      TORCH_CHECK(
          !isBatchedTensor(mb_output.value()),
          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
          "torch.vmap. We do not support the case where any outputs are ",
          "vmapped tensors (output ",
          i,
          " is being vmapped over). Please "
          "call autograd.grad() outside torch.vmap or file a bug report "
          "with your use case.");
      gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
    } else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
      gradient_edge = parseGradientEdge(_tensor, i);
    } else {
      TORCH_CHECK(
          false,
          "element ",
          i,
          " of tensors tuple is neither a Tensor nor a GradientEdge");
    }
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));

    PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ",
            grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      TORCH_CHECK(
          grad == Py_None,
          "element ",
          i,
          " of gradients tuple is not a Tensor or None");
      TORCH_CHECK(
          mb_output.has_value(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding output is a GradientEdge. "
          "This is not supported.");
      TORCH_CHECK(
          !mb_output.value().requires_grad(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding Tensor requires grad");
    }
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    TORCH_CHECK(
        PyTuple_CheckExact(inputs), "inputs to run_backward must be a tuple");
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = PyTuple_GET_ITEM(inputs, i);
      if (THPVariable_Check(input)) {
        const auto& tensor = THPVariable_Unpack(input);
        TORCH_CHECK(
            !isBatchedTensor(tensor),
            "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
            "torch.vmap. We do not support the case where any inputs are ",
            "vmapped tensors (input ",
            i,
            " is being vmapped over). Please "
            "call autograd.grad() outside torch.vmap or file a bug report "
            "with your use case.")
        const auto output_nr = tensor.output_nr();
        auto grad_fn = tensor.grad_fn();
        if (!grad_fn) {
          grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
        }
        if (accumulate_grad) {
          tensor.retain_grad();
        }
        TORCH_CHECK(
            tensor.requires_grad(),
            "One of the differentiated Tensors does not require grad");
        if (!grad_fn) {
          // NOTE [ Autograd Unreachable Input ]
          // Since input has no grad_accumulator, it's guaranteed to be
          // unreachable. We initialize an edge pointing to a non-nullptr Node
          // so nodes in the graph (e.g., mul when an operand is scalar) that
          // have edges pointing to nullptr don't get erroneously assigned
          // `needed = True` in exec_info.
          output_edges.emplace_back(std::make_shared<Identity>(), 0);
        } else {
          output_edges.emplace_back(grad_fn, output_nr);
        }
      } else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
        output_edges.emplace_back(parseGradientEdge(input, i));
      } else {
        TORCH_CHECK(
            false,
            "all inputs have to be Tensors or GradientEdges, but got ",
            THPUtils_typename(input));
      }
    }
  }

  variable_list outputs;
  {
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(
        roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }

  if (!backward_api_called && inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs{PyTuple_New(num_inputs)};
    if (!py_outputs)
      return nullptr;
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          allow_unreachable || outputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  std::shared_ptr<PyObject> callback(_callback, [](PyObject* obj) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  });
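  // The callback may be released on an autograd worker thread that does not
  // hold the GIL, so the custom deleter above re-acquires the GIL before
  // dropping the reference taken by Py_INCREF below.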
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) {
      // Note [ Persisting PyErr state across autograd engine threads ]
      //
      // Since the autograd engine is multi-threaded, and Python error state is
      // local to each thread, the engine must preserve the Python error raised
      // on the worker thread and rethrow it as-is in the calling thread. This
      // is done by persisting the error in the two places that can encounter
      // Python errors: (1) evaluate function and (2) queued callbacks.
      //
      // TODO: the engine is not actually responsible for persisting the error
      // in the custom autograd Function case today! See the note above
      // `raise_python_error()` function in python_function.cpp and
      // python_hooks.cpp for more details. Persisting an extra time in the
      // engine is fine because doing so is a no-op when the python_error has
      // already been persisted.
      python_error err;
      err.persist();
      throw std::move(err);
    }
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  if (engine.is_checkpoint_valid()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  return type->tp_alloc(type, 0);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPEngine_methods[] = {
    {(char*)"run_backward",
     castPyCFunctionWithKeywords(THPEngine_run_backward),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
    {(char*)"is_checkpoint_valid",
     THPEngine_is_checkpoint_valid,
     METH_NOARGS,
     nullptr},
    {nullptr}};

PyTypeObject THPEngineType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */
    sizeof(THPEngine), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPEngine_methods, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPEngine_new /* tp_new */
};

static void child_atfork() {
  _reinitialize_engine = true;
}

bool THPEngine_initModule(PyObject* module) {
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject*)&THPEngineType);
  set_default_engine_stub(python::PythonEngine::get_python_engine);
  return true;
}