#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/sparse/SparseFactories.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_spdiags_native.h>
#include <ATen/ops/_unique.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sparse_coo_tensor.h>
#include <ATen/ops/where.h>
#endif

namespace at::native {

DEFINE_DISPATCH(spdiags_kernel_stub);

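// Constructs a sparse matrix of the given `shape` by placing the rows of
// `diagonals` along the diagonals selected by `offsets` (analogous to
// scipy.sparse.spdiags). The element written at output position (r, c) on
// the diagonal with offset d = c - r is taken from column c of the
// corresponding row of `diagonals`.
//
// Illustrative example (not part of this file): with
//   diagonals = [[1, 2, 3],
//                [4, 5, 6]]
//   offsets   = [0, -1]
//   shape     = (3, 3)
// the dense equivalent of the result is
//   [[1, 0, 0],
//    [4, 2, 0],
//    [0, 5, 3]]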
Tensor spdiags(
    const Tensor& diagonals,
    const Tensor& offsets,
    IntArrayRef shape,
    std::optional<Layout> layout) {
  auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals;
  TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix");
  TORCH_CHECK(shape.size() == 2, "Output shape must be 2d");
  auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets;
  TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector");
  TORCH_CHECK(
      diagonals_2d.size(0) == offsets_1d.size(0),
      "Number of diagonals (",
      diagonals_2d.size(0),
      ") does not match the number of offsets (",
      offsets_1d.size(0),
      ")");
  if (layout) {
    TORCH_CHECK(
        (*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) ||
            (*layout == Layout::SparseCsr),
        "Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ",
        *layout);
  }
  TORCH_CHECK(
      offsets_1d.scalar_type() == at::kLong,
      "Offset Tensor must have dtype Long but got ",
      offsets_1d.scalar_type());

  TORCH_CHECK(
      offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(),
      "Offset tensor contains duplicate values");
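  // Per-diagonal nonzero counts. For an output of shape (m, n) where the
  // diagonals buffer has K columns:
  //   offset d <= 0 (main/lower): min(m + d, K) elements
  //   offset d >  0 (upper):      min(n, K) - d elements
  // E.g. for shape (3, 3) and K = 3: d = 0 -> 3, d = -1 -> 2, d = 2 -> 1.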
  auto nnz_per_diag = at::where(
      offsets_1d.le(0),
      offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)),
      offsets_1d.add(-std::min<int64_t>(shape[1], diagonals_2d.size(1))).neg());
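  // Total nnz is the last entry of the cumulative sum; guard the empty case,
  // since select(-1, -1) is invalid on a zero-length dimension.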
  auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1);
  const auto nnz = diagonals_2d.size(0) > 0
      ? nnz_per_diag_cumsum.select(-1, -1).item<int64_t>()
      : int64_t{0};
  // Exclusive prefix sum: the position in values/indices where each diagonal
  // starts writing
  auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag);
  // COO tensor guts
  auto indices = at::empty({2, nnz}, offsets_1d.options());
  auto values = at::empty({nnz}, diagonals_2d.options());
  // We add this indexer to look up the row of diagonals we are reading from
  // at each iteration
  const auto n_diag = offsets_1d.size(0);
  Tensor diag_index = at::arange(n_diag, offsets_1d.options());
  // cpu_kernel requires an output
  auto dummy = at::empty({1}, offsets_1d.options()).resize_({0});
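  // The iterator only drives per-diagonal iteration over the metadata above;
  // the kernel writes `values` and `indices` directly rather than through
  // iterator outputs, hence the zero-sized dummy output and the disabled
  // memory-overlap check.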
  auto iter = TensorIteratorConfig()
                  .set_check_mem_overlap(false)
                  .add_output(dummy)
                  .add_input(diag_index)
                  .add_input(offsets_1d)
                  .add_input(result_mem_offsets)
                  .add_input(nnz_per_diag)
                  .build();
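  // For each diagonal, the kernel fills in the row/column indices and copies
  // the corresponding elements of diagonals_2d into values.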
  spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices);
  auto result_coo = at::sparse_coo_tensor(indices, values, shape);
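  // The kernel always produces COO; convert only when a compressed layout was
  // explicitly requested (Layout::Sparse is the COO result as-is).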
  if (layout) {
    if (*layout == Layout::SparseCsr) {
      return result_coo.to_sparse_csr();
    }
    if (*layout == Layout::SparseCsc) {
      return result_coo.to_sparse_csc();
    }
  }
  return result_coo;
}

} // namespace at::native