xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_nested_functions_manual.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/nested.h>
2 #include <torch/csrc/utils/pycfunction_helpers.h>
3 #include <torch/csrc/utils/python_arg_parser.h>
4 #include <torch/torch.h>
5 
6 namespace torch::autograd {
7 
THPVariable_nested_tensor(PyObject *,PyObject * args,PyObject * kwargs)8 static PyObject* THPVariable_nested_tensor(
9     PyObject* /*self*/,
10     PyObject* args,
11     PyObject* kwargs) {
12   HANDLE_TH_ERRORS
13   static PythonArgParser parser({
14       "nested_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
15   });
16 
17   constexpr int ctor_num_args = 5;
18   ParsedArgs<ctor_num_args> parsed_args;
19   auto r = parser.parse(args, kwargs, parsed_args);
20 
21   jit::tracer::warn(
22       "torch.nested.nested_tensor", jit::tracer::WARN_CONSTRUCTOR);
23   return THPVariable_Wrap(torch::utils::nested_tensor_ctor(
24       torch::tensors::get_default_dispatch_key(),
25       torch::tensors::get_default_scalar_type(),
26       r));
27   END_HANDLE_TH_ERRORS
28 }
29 
30 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
31 static PyMethodDef nested_functions_manual[] = {
32     {"nested_tensor",
33      castPyCFunctionWithKeywords(THPVariable_nested_tensor),
34      METH_VARARGS | METH_KEYWORDS,
35      nullptr},
36 };
37 
get_nested_functions_manual()38 PyMethodDef* get_nested_functions_manual() {
39   return nested_functions_manual;
40 }
41 
42 } // namespace torch::autograd
43