#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/Sorting.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Context.h>
#include <ATen/TensorUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include <ATen/native/SortingUtils.h>
#include <ATen/native/ReduceOpsUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/full.h>
#include <ATen/ops/kthvalue_native.h>
#include <ATen/ops/median_native.h>
#include <ATen/ops/nanmedian_native.h>
#include <ATen/ops/where.h>
#include <ATen/ops/rsub.h>
#include <ATen/ops/div.h>
#include <ATen/ops/index.h>
#endif

namespace at::native {
namespace {

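// Writes the k-th smallest element (1-based k) of each slice of `self` along
// `dim` into `values`, and its position within that slice into `indices`.
// For example, for a 1-D input {3, 1, 2} and k = 2 the sorted slice is
// {1, 2, 3}, so values = 2 and indices = 2 (the position of 2 in the input).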
std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim());
  int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
  zero_numel_check_dims(self, dim, "kthvalue()");

  TORCH_CHECK(k >= 1 && k <= slicesize,
              "kthvalue(): selected number k out of range for dimension ", dim);
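  // k is 1-based: k == 1 selects the smallest element of each slice and
  // k == slicesize the largest.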

  at::assert_no_overlap(self, values);

  _reduction_with_indices_allocate_or_resize_output(
      values, indices, self, dim, keepdim);
  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
    return std::forward_as_tuple(values, indices);
  }

  TORCH_CHECK(
      self.dim() <= MAX_TENSORINFO_DIMS,
      "cannot operate on more than ",
      MAX_TENSORINFO_DIMS,
      " dimensions");

  // Based on required index size, run the algorithm with the
  // appropriate index type
  if (self.numel() != 0) {
    launch_kthvalue_kernel(values, indices, self, dim, k);
  }

  if (!keepdim) {
    values.squeeze_(dim);
    indices.squeeze_(dim);
  }
  return std::forward_as_tuple(values, indices);
}

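// Computes the median of `self` along `dim`, writing the median values into
// `values` and their positions along `dim` into `indices`. With
// ignore_nan == true this implements nanmedian (NaN values are skipped);
// otherwise a slice that contains NaN yields NaN.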
std::tuple<Tensor&, Tensor&> median_with_indices_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    bool ignore_nan) {
  // See note [Writing Nondeterministic Operations]
  // If there are duplicate elements of a median value, the procedure for choosing which
  // of the duplicates to use for the indices output is nondeterministic.
  at::globalContext().alertNotDeterministic("median CUDA with indices output");
  NoNamesGuard guard;

  dim = at::maybe_wrap_dim(dim, self.dim());
  Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0);

  checkDeviceType("median", {values, indices}, self.device().type());
  checkScalarType("median", {indices, "indices", 1}, kLong);
  checkSameType("median", {values, "values", 0}, {self, "self", 2});

  TORCH_CHECK(
      self.dim() <= MAX_TENSORINFO_DIMS,
      "median() cannot operate on more than ",
      MAX_TENSORINFO_DIMS,
      " dimensions");

  std::vector<int64_t> out_shape = self.sizes().vec();
  zero_numel_check_dims(self, dim, "median()");
  if (self.dim() > 0) {
    assert(dim >= 0);
    assert(dim < static_cast<int64_t>(out_shape.size()));

    if (keepdim) {
      out_shape[dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + dim);
    }
  }
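  // For example, self.sizes() == (3, 4, 5) with dim == 1 gives an output
  // shape of (3, 1, 5) when keepdim is true and (3, 5) otherwise.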

  values.resize_(out_shape);
  indices.resize_(out_shape);

  // Only launch kernel for non-empty tensors
  if (self.numel() > 0) {
    // Ensure #dim is the same for all tensors required for reduction
    Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim);
    Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim);

    launch_median_kernel(vals, inds, in, dim, ignore_nan);
  }

  guard.reset();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);

  return std::forward_as_tuple(values, indices);
}

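// Reduces all elements of `self` to a single median scalar. With
// ignore_nan == false this is torch.median (NaN if any element is NaN);
// with ignore_nan == true it is torch.nanmedian (NaN values are ignored).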
Tensor median_impl(const Tensor& self, bool ignore_nan) {
  NoNamesGuard guard;

  int64_t size = self.numel();
  // Return nan for empty tensors
  if (size <= 0) {
    return at::full({}, std::numeric_limits<float>::quiet_NaN()).to(self.options());
  }

  // Sort input tensor to efficiently query for median element
  Tensor sorted = std::get<0>(self.flatten().sort());
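  // sort() places NaN values after all numeric values, which both branches
  // below rely on.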

  if (!ignore_nan) {
    // For torch.median return either the middle element or nan (sorted as
    // largest) if there are any
    int64_t k = (size - 1) / 2;
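    // For an even number of elements this picks the lower of the two middle
    // values, e.g. size == 4 gives k == 1.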
    return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]);
  } else {
    // For torch.nanmedian return the middle element among the non-nan values
    Tensor k = at::div(at::rsub(sorted.isnan().sum(), (size - 1)), 2).to(kLong);
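    // rsub(a, b) computes b - a, so k == ((size - 1) - num_nan) / 2; since
    // NaNs sort last, this indexes the middle of the non-NaN prefix.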
    return at::index(sorted, {k});
  }
}

} // namespace (anonymous)

std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  // See note [Writing Nondeterministic Operations]
  // If there are duplicate elements of the kth value, the procedure for choosing which
  // of the duplicates to use for the indices output is nondeterministic.
  at::globalContext().alertNotDeterministic("kthvalue CUDA");
  auto result = [&]() {
    NoNamesGuard guard;
    // `kthvalue_out_impl_cuda` expects a contiguous input `self`.
    return kthvalue_out_impl_cuda(values, indices, self.contiguous(), k, dim, keepdim);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

// Mark: median

std::tuple<Tensor&, Tensor&> median_out_cuda(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return median_with_indices_impl(
      values, indices, self, dim, keepdim, /*ignore_nan=*/false);
}

Tensor median_cuda(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/false);
}

std::tuple<Tensor&, Tensor&> nanmedian_out_cuda(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return median_with_indices_impl(
      values, indices, self, dim, keepdim, /*ignore_nan=*/true);
}

Tensor nanmedian_cuda(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/true);
}

} // namespace at::native