#include <torch/csrc/autograd/python_engine.h>

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapMode.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>

#ifndef _WIN32
#include <pthread.h>
#endif

#include <memory> // for unique_ptr
#include <utility>

using namespace torch::autograd;

struct THPEngine {
  PyObject_HEAD
};

static bool _reinitialize_engine = false;

namespace torch::autograd::python {

PythonEngine::PythonEngine() = default;

Engine& PythonEngine::get_python_engine() {
  static PythonEngine engine;
  // This is "probably" thread-safe because the flag is set in a fork handler
  // before any threads are created, and this function is only called with the
  // GIL held. However, using fork + threads is playing with fire so this is
  // more of a "best effort" thing. For example, if the fork occurs while the
  // backwards threads hold a lock, we'll probably deadlock in the engine
  // destructor.
  if (_reinitialize_engine) {
    engine.release_workers();
    engine.~PythonEngine();
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
  return engine;
}

PythonEngine::~PythonEngine() {
  Engine::stop();
}

#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9
#define IS_PYTHON_3_9_PLUS
#endif

void PythonEngine::thread_init(
    int device,
    const std::shared_ptr<ReadyQueue>& ready_queue,
    bool should_increment) {
  // Increment thread usage count before acquiring the GIL
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }
  // Create a PyThreadState, but release the GIL. This lets
  // pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL
  // without having to create a new PyThreadState each time.
#if defined(IS_PYTHON_3_9_PLUS)
  auto gil = std::make_unique<pybind11::gil_scoped_acquire>();
#else
  pybind11::gil_scoped_acquire gil;
#endif
  pybind11::gil_scoped_release no_gil;
  Engine::thread_init(device, ready_queue, false);

  if (should_increment) {
    // Decrement the count during shutdown if we incremented earlier.
    decrement_non_reentrant_thread_count();
  }

#if defined(IS_PYTHON_3_9_PLUS)
  // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if
  // runtime is finalizing
  if (!Py_IsInitialized()) {
    no_gil.disarm();
    // TODO: call disarm once PyThreadState_Clear can safely be called from
    // finalize NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct
    // PyThreadState, so avoid use-after-free here.
    auto ptr = gil.release();
    operator delete(ptr);
  }
#endif
}

void PythonEngine::thread_on_exception(
    const std::shared_ptr<GraphTask>& graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  // See Note [ Persisting PyErr state across autograd engine threads ]
  auto python_err = dynamic_cast<python_error*>(&e);
  if (python_err) {
    python_err->persist();
  }
  Engine::thread_on_exception(graph_task, fn, e);
}

std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  return std::make_unique<PyAnomalyMetadata>();
}

std::unique_ptr<SavedVariableHooks> PythonEngine::
    get_default_saved_variable_hooks() {
  return PyDefaultSavedVariableHooks::get_hooks();
}

variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    bool accumulate_grad,
    const edge_list& outputs) {
  TORCH_CHECK(
      !PyGILState_Check(),
      "The autograd engine was called while holding the GIL. If you are using the C++ "
      "API, the autograd engine is an expensive operation that does not require the "
      "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
      ". If you are not using the C++ API, please report a bug to the pytorch team.")
  try {
    return Engine::execute(
        roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}
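// Note [ Calling the engine from C++ with the GIL held ]
// A minimal illustrative sketch of what the TORCH_CHECK above asks of C++
// callers: release the GIL before entering the engine, since the backward
// pass does not need it and worker threads re-acquire it only when they must
// run Python code (hooks, custom Functions). `loss` here is a hypothetical
// scalar tensor with requires_grad() == true.
//
//   {
//     pybind11::gil_scoped_release no_gil; // drop the GIL for the duration
//     loss.backward();                     // eventually reaches PythonEngine::execute
//   }
//
// Holding the GIL across the call risks deadlock once a worker thread needs
// the GIL for a Python hook.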
c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  try {
    return Engine::execute_with_graph_task(
        graph_task, std::move(graph_root), std::move(input_buffer));
  } catch (python_error& e) {
    pybind11::gil_scoped_acquire gil;
    if (!PyErr_Occurred()) {
      // Set the error indicator only if it is not set already.
      e.restore();
    }
    throw;
  }
}
} // namespace torch::autograd::python

PyObject* THPEngineClass = nullptr;

inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
  PyObject* grad_fn = PyTuple_GetItem(obj, 0);
  auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
  std::shared_ptr<torch::autograd::Node> grad_fn_sp;
  if (THPFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
  } else if (THPCppFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
  } else {
    TORCH_CHECK(
        false,
        "GradientEdge's first object must be an autograd.graph.Node "
        "but got ",
        THPUtils_typename(grad_fn));
  }
  return Edge(grad_fn_sp, output_nr);
}

// Implementation of torch._C._EngineBase.run_backward
PyObject* THPEngine_run_backward(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  PyObject* tensors = nullptr;
  PyObject* grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject* inputs = nullptr;
  unsigned char allow_unreachable = 0;
  unsigned char accumulate_grad =
      0; // Indicate whether to accumulate grad into leaf Tensors or capture
  constexpr const char* accepted_kwargs[] = {// NOLINT
      "tensors",
      "grad_tensors",
      "keep_graph",
      "create_graph",
      "inputs",
      "allow_unreachable",
      "accumulate_grad",
      nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "OObb|Obb",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,-warnings-as-errors)
          const_cast<char**>(accepted_kwargs),
          &tensors,
          &grad_tensors,
          &keep_graph,
          &create_graph,
          &inputs,
          &allow_unreachable,
          &accumulate_grad))
    return nullptr;
  TORCH_CHECK(
      PyTuple_Check(tensors),
      "tensors argument is expected to "
      "be a tuple, but got ",
      THPUtils_typename(tensors));
  TORCH_CHECK(
      PyTuple_Check(grad_tensors),
      "grad_tensors argument is "
      "expected to be a tuple, but got ",
      THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  TORCH_CHECK(
      num_tensors == num_gradients,
      "got ",
      num_tensors,
      " tensors and ",
      num_gradients,
      " gradients");

  // The user either called autograd.backward(...) or autograd.grad(...) to get
  // here
  bool backward_api_called = accumulate_grad;
  TORCH_CHECK(
      !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
    Edge gradient_edge; // Temporary variable to hold the gradient edge
    std::optional<at::Tensor> mb_output;
    if (THPVariable_Check(_tensor)) {
      mb_output = THPVariable_Unpack(_tensor);
      TORCH_CHECK(
          !isBatchedTensor(mb_output.value()),
          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
          "torch.vmap. We do not support the case where any outputs are ",
          "vmapped tensors (output ",
          i,
          " is being vmapped over). Please "
          "call autograd.grad() outside torch.vmap or file a bug report "
          "with your use case.");
      gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
    } else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
      gradient_edge = parseGradientEdge(_tensor, i);
    } else {
      TORCH_CHECK(
          false,
          "element ",
          i,
          " of tensors tuple is neither a Tensor nor a GradientEdge");
    }
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));

    PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ",
            grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      TORCH_CHECK(
          grad == Py_None,
          "element ",
          i,
          " of gradients tuple is not a Tensor or None");
      TORCH_CHECK(
          mb_output.has_value(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding output is a GradientEdge. "
          "This is not supported.");
      TORCH_CHECK(
          !mb_output.value().requires_grad(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding Tensor requires grad");
    }
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    TORCH_CHECK(
        PyTuple_CheckExact(inputs), "inputs to run_backward must be a tuple");
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = PyTuple_GET_ITEM(inputs, i);
      if (THPVariable_Check(input)) {
        const auto& tensor = THPVariable_Unpack(input);
        TORCH_CHECK(
            !isBatchedTensor(tensor),
            "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
            "torch.vmap. We do not support the case where any inputs are ",
            "vmapped tensors (input ",
            i,
            " is being vmapped over). Please "
            "call autograd.grad() outside torch.vmap or file a bug report "
            "with your use case.")
        const auto output_nr = tensor.output_nr();
        auto grad_fn = tensor.grad_fn();
        if (!grad_fn) {
          grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
        }
        if (accumulate_grad) {
          tensor.retain_grad();
        }
        TORCH_CHECK(
            tensor.requires_grad(),
            "One of the differentiated Tensors does not require grad");
        if (!grad_fn) {
          // NOTE [ Autograd Unreachable Input ]
          // Since input has no grad_accumulator, it's guaranteed to be
          // unreachable. We initialize an edge pointing to a non-nullptr Node
          // so nodes in the graph (e.g., mul when an operand is scalar) that
          // have edges pointing to nullptr don't get erroneously assigned
          // `needed = True` in exec_info.
          output_edges.emplace_back(std::make_shared<Identity>(), 0);
        } else {
          output_edges.emplace_back(grad_fn, output_nr);
        }
      } else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
        output_edges.emplace_back(parseGradientEdge(input, i));
      } else {
        TORCH_CHECK(
            false,
            "all inputs have to be Tensors or GradientEdges, but got ",
            THPUtils_typename(input));
      }
    }
  }

  variable_list outputs;
  {
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(
        roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }

  if (!backward_api_called && inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs{PyTuple_New(num_inputs)};
    if (!py_outputs)
      return nullptr;
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          allow_unreachable || outputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}
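// Note [ run_backward serves both backward() and grad() ]
// Rough summary of the calling convention above: the Python side is expected
// to reach this entry point through Variable._execution_engine.run_backward.
// With accumulate_grad != 0 (torch.autograd.backward) gradients are
// accumulated into the leaves' .grad fields and None is returned; with
// accumulate_grad == 0 and `inputs` given (torch.autograd.grad) the computed
// gradients for those inputs are returned as a tuple instead.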
PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  std::shared_ptr<PyObject> callback(_callback, [](PyObject* obj) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) {
      // Note [ Persisting PyErr state across autograd engine threads ]
      //
      // Since the autograd engine is multi-threaded, and Python error state is
      // local to each thread, it must preserve the python error from the
      // worker thread and rethrow it as-is in the calling thread. This is done
      // via persisting the error in the two places that can encounter Python
      // errors: (1) evaluate function and (2) queued callbacks.
      //
      // TODO: the engine is not actually responsible for persisting the error
      // in the custom autograd Function case today! See the note above
      // `raise_python_error()` function in python_function.cpp and
      // python_hooks.cpp for more details. Persisting an extra time in the
      // engine is fine because doing so is a no-op when the python_error has
      // already been persisted.
      python_error err;
      err.persist();
      throw std::move(err);
    }
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  if (engine.is_checkpoint_valid()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  return type->tp_alloc(type, 0);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPEngine_methods[] = {
    {(char*)"run_backward",
     castPyCFunctionWithKeywords(THPEngine_run_backward),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
    {(char*)"is_checkpoint_valid",
     THPEngine_is_checkpoint_valid,
     METH_NOARGS,
     nullptr},
    {nullptr}};

PyTypeObject THPEngineType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */
    sizeof(THPEngine), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPEngine_methods, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPEngine_new /* tp_new */
};

static void child_atfork() {
  _reinitialize_engine = true;
}

bool THPEngine_initModule(PyObject* module) {
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject*)&THPEngineType);
  set_default_engine_stub(python::PythonEngine::get_python_engine);
  return true;
}
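// Note [ How the Python engine becomes the default engine ]
// Roughly: THPEngine_initModule registers PythonEngine::get_python_engine via
// set_default_engine_stub, so C++ entry points that ask for the default
// engine, e.g.
//
//   auto& engine = torch::autograd::Engine::get_default_engine();
//
// receive this Python-aware engine once torch._C has been initialized, which
// is what routes anomaly metadata and saved-variable hooks through their
// Python implementations.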