1 #include <torch/csrc/utils/nested.h>
2 #include <torch/csrc/utils/pycfunction_helpers.h>
3 #include <torch/csrc/utils/python_arg_parser.h>
4 #include <torch/torch.h>
5
6 namespace torch::autograd {
7
THPVariable_nested_tensor(PyObject *,PyObject * args,PyObject * kwargs)8 static PyObject* THPVariable_nested_tensor(
9 PyObject* /*self*/,
10 PyObject* args,
11 PyObject* kwargs) {
12 HANDLE_TH_ERRORS
13 static PythonArgParser parser({
14 "nested_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
15 });
16
17 constexpr int ctor_num_args = 5;
18 ParsedArgs<ctor_num_args> parsed_args;
19 auto r = parser.parse(args, kwargs, parsed_args);
20
21 jit::tracer::warn(
22 "torch.nested.nested_tensor", jit::tracer::WARN_CONSTRUCTOR);
23 return THPVariable_Wrap(torch::utils::nested_tensor_ctor(
24 torch::tensors::get_default_dispatch_key(),
25 torch::tensors::get_default_scalar_type(),
26 r));
27 END_HANDLE_TH_ERRORS
28 }
29
30 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
31 static PyMethodDef nested_functions_manual[] = {
32 {"nested_tensor",
33 castPyCFunctionWithKeywords(THPVariable_nested_tensor),
34 METH_VARARGS | METH_KEYWORDS,
35 nullptr},
36 };
37
get_nested_functions_manual()38 PyMethodDef* get_nested_functions_manual() {
39 return nested_functions_manual;
40 }
41
42 } // namespace torch::autograd
43