xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/SparseFactories.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Dispatch.h>
2 #include <ATen/TensorIterator.h>
3 #include <ATen/native/sparse/SparseFactories.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/_spdiags_native.h>
10 #include <ATen/ops/_unique.h>
11 #include <ATen/ops/arange.h>
12 #include <ATen/ops/empty.h>
13 #include <ATen/ops/sparse_coo_tensor.h>
14 #include <ATen/ops/where.h>
15 #endif
16 
17 namespace at::native {
18 
19 DEFINE_DISPATCH(spdiags_kernel_stub);
20 
spdiags(const Tensor & diagonals,const Tensor & offsets,IntArrayRef shape,std::optional<Layout> layout)21 Tensor spdiags(
22     const Tensor& diagonals,
23     const Tensor& offsets,
24     IntArrayRef shape,
25     std::optional<Layout> layout) {
26   auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals;
27   TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix");
28   TORCH_CHECK(shape.size() == 2, "Output shape must be 2d");
29   auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets;
30   TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector");
31   TORCH_CHECK(
32       diagonals_2d.size(0) == offsets_1d.size(0),
33       "Number of diagonals (",
34       diagonals_2d.size(0),
35       ") does not match the number of offsets (",
36       offsets_1d.size(0),
37       ")");
38   if (layout) {
39     TORCH_CHECK(
40         (*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) ||
41             (*layout == Layout::SparseCsr),
42         "Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ",
43         *layout);
44   }
45   TORCH_CHECK(
46       offsets_1d.scalar_type() == at::kLong,
47       "Offset Tensor must have dtype Long but got ",
48       offsets_1d.scalar_type());
49 
50   TORCH_CHECK(
51       offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(),
52       "Offset tensor contains duplicate values");
53 
54   auto nnz_per_diag = at::where(
55       offsets_1d.le(0),
56       offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)),
57       offsets_1d.add(-std::min<int64_t>(shape[1], diagonals_2d.size(1))).neg());
58 
59   auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1);
60   const auto nnz = diagonals_2d.size(0) > 0
61       ? nnz_per_diag_cumsum.select(-1, -1).item<int64_t>()
62       : int64_t{0};
63   // Offsets into nnz for each diagonal
64   auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag);
65   // coo tensor guts
66   auto indices = at::empty({2, nnz}, offsets_1d.options());
67   auto values = at::empty({nnz}, diagonals_2d.options());
68   // We add this indexer to lookup the row of diagonals we are reading from at
69   // each iteration
70   const auto n_diag = offsets_1d.size(0);
71   Tensor diag_index = at::arange(n_diag, offsets_1d.options());
72   // cpu_kernel requires an output
73   auto dummy = at::empty({1}, offsets_1d.options()).resize_({0});
74   auto iter = TensorIteratorConfig()
75                   .set_check_mem_overlap(false)
76                   .add_output(dummy)
77                   .add_input(diag_index)
78                   .add_input(offsets_1d)
79                   .add_input(result_mem_offsets)
80                   .add_input(nnz_per_diag)
81                   .build();
82   spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices);
83   auto result_coo = at::sparse_coo_tensor(indices, values, shape);
84   if (layout) {
85     if (*layout == Layout::SparseCsr) {
86       return result_coo.to_sparse_csr();
87     }
88     if (*layout == Layout::SparseCsc) {
89       return result_coo.to_sparse_csc();
90     }
91   }
92   return result_coo;
93 }
94 
95 } // namespace at::native
96