#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/Sorting.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Context.h>
#include <ATen/TensorUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include <ATen/native/SortingUtils.h>
#include <ATen/native/ReduceOpsUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/full.h>
#include <ATen/ops/kthvalue_native.h>
#include <ATen/ops/median_native.h>
#include <ATen/ops/nanmedian_native.h>
#include <ATen/ops/where.h>
#include <ATen/ops/rsub.h>
#include <ATen/ops/div.h>
#include <ATen/ops/index.h>
#endif

namespace at::native {
namespace {

std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim());
  int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
  zero_numel_check_dims(self, dim, "kthvalue()");

  TORCH_CHECK(k >= 1 && k <= slicesize,
              "kthvalue(): selected number k out of range for dimension ", dim);

  at::assert_no_overlap(self, values);

  _reduction_with_indices_allocate_or_resize_output(
      values, indices, self, dim, keepdim);
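  // A 0-dim input holds a single element, which is trivially its own k-th
  // value; fill the outputs directly without launching the kernel.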
  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
    return std::forward_as_tuple(values, indices);
  }

  TORCH_CHECK(
      self.dim() <= MAX_TENSORINFO_DIMS,
      "cannot operate on more than ",
      MAX_TENSORINFO_DIMS,
      " dimensions");

  // Based on required index size, run the algorithm with the
  // appropriate index type
  if (self.numel() != 0) {
    launch_kthvalue_kernel(values, indices, self, dim, k);
  }

  if (!keepdim) {
    values.squeeze_(dim);
    indices.squeeze_(dim);
  }
  return std::forward_as_tuple(values, indices);
}

std::tuple<Tensor&, Tensor&> median_with_indices_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    bool ignore_nan) {
  // See note [Writing Nondeterministic Operations]
  // If there are duplicate elements of a median value, the procedure for choosing which
  // of the duplicates to use for the indices output is nondeterministic.
  at::globalContext().alertNotDeterministic("median CUDA with indices output");
  NoNamesGuard guard;

  dim = at::maybe_wrap_dim(dim, self.dim());
  Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0);

  checkDeviceType("median", {values, indices}, self.device().type());
  checkScalarType("median", {indices, "indices", 1}, kLong);
  checkSameType("median", {values, "values", 0}, {self, "self", 2});

  TORCH_CHECK(
      self.dim() <= MAX_TENSORINFO_DIMS,
      "median() cannot operate on more than ",
      MAX_TENSORINFO_DIMS,
      " dimensions");

  std::vector<int64_t> out_shape = self.sizes().vec();
  zero_numel_check_dims(self, dim, "median()");
  if (self.dim() > 0) {
    assert(dim >= 0);
    assert(dim < static_cast<int64_t>(out_shape.size()));

    if (keepdim) {
      out_shape[dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + dim);
    }
  }

  values.resize_(out_shape);
  indices.resize_(out_shape);

  // Only launch kernel for non-empty tensors
  if (self.numel() > 0) {
    // Ensure #dim is the same for all tensors required for reduction
    Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim);
    Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim);

    launch_median_kernel(vals, inds, in, dim, ignore_nan);
  }

  guard.reset();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);

  return std::forward_as_tuple(values, indices);
}

Tensor median_impl(const Tensor& self, bool ignore_nan) {
  NoNamesGuard guard;

  int64_t size = self.numel();
  // Return nan for empty tensors
  if (size <= 0) {
    return at::full({}, std::numeric_limits<float>::quiet_NaN()).to(self.options());
  }

  // Sort input tensor to efficiently query for median element
  Tensor sorted = std::get<0>(self.flatten().sort());

  if (!ignore_nan) {
    // For torch.median, return either the middle element or, if there are
    // any NaNs, nan (which sorts as the largest element)
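    // (size - 1) / 2 is the lower of the two middle indices when size is
    // even, matching torch.median's convention of returning the lower median.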
    int64_t k = (size - 1) / 2;
    return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]);
  } else {
    // For torch.nanmedian return the middle element among the non-nan values
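    // rsub(num_nan, size - 1) computes (size - 1) - num_nan, so k is the
    // lower-middle index among the non-nan values (nan sorts to the end).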
    Tensor k = at::div(at::rsub(sorted.isnan().sum(), (size - 1)), 2).to(kLong);
    return at::index(sorted, {k});
  }
}

} // namespace (anonymous)

std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  // See note [Writing Nondeterministic Operations]
  // If there are duplicate elements of the kth value, the procedure for choosing which
  // of the duplicates to use for the indices output is nondeterministic.
  at::globalContext().alertNotDeterministic("kthvalue CUDA");
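  // Run the computation inside a lambda so the NoNamesGuard is reset before
  // names are propagated to the outputs below.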
  auto result = [&]() {
    NoNamesGuard guard;
    // `kthvalue_out_impl_cuda` expects the input `self` to be contiguous.
    return kthvalue_out_impl_cuda(values, indices, self.contiguous(), k, dim, keepdim);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

// Mark: median

std::tuple<Tensor&, Tensor&> median_out_cuda(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return median_with_indices_impl(
      values, indices, self, dim, keepdim, /*ignore_nan=*/false);
}

Tensor median_cuda(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/false);
}

std::tuple<Tensor&, Tensor&> nanmedian_out_cuda(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return median_with_indices_impl(
      values, indices, self, dim, keepdim, /*ignore_nan=*/true);
}

Tensor nanmedian_cuda(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/true);
}

} // namespace at::native