#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/native/AdaptivePooling.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>

namespace at::native {

namespace {
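
// Adaptive average pooling: each output element is the mean of an input window
// whose bounds come from start_index()/end_index() in ATen/native/AdaptivePooling.h.
// For output index o along an axis with input size isize and output size osize,
// the window is [floor(o * isize / osize), ceil((o + 1) * isize / osize)), e.g.
// isize = 5, osize = 3 yields the windows [0, 2), [1, 4), [3, 5). Window extents
// (kh/kw, plus kd in 3D) may therefore differ by one between neighboring outputs.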
template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_avg_pool2d(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto input = input_.contiguous();
  auto output = output_.contiguous();

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t ndim = input.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      const scalar_t* input_ptr = input_data + c * input_height * input_width;
      scalar_t* output_ptr = output_data + c * output_height * output_width;

      for (const auto oh : c10::irange(output_height)) {
        int64_t ih0 = start_index(oh, output_height, input_height);
        int64_t ih1 = end_index(oh, output_height, input_height);
        int64_t kh = ih1 - ih0;

        for (const auto ow : c10::irange(output_width)) {
          int64_t iw0 = start_index(ow, output_width, input_width);
          int64_t iw1 = end_index(ow, output_width, input_width);
          int64_t kw = iw1 - iw0;

          // compute local average
          accscalar_t sum = 0;
          for (const auto ih : c10::irange(ih0, ih1)) {
            for (const auto iw : c10::irange(iw0, iw1)) {
              sum += accscalar_t(input_ptr[ih * input_width + iw]);
            }
          }
          output_ptr[oh * output_width + ow] = scalar_t(sum / kh / kw);
        }
      }
    }
  });

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
}

template <typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool2d_channels_last(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto memory_format = at::MemoryFormat::ChannelsLast;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N, H, W
  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

    for (const auto i : c10::irange(begin, end)) {
      int64_t ih0 = start_index(oh, output_height, input_height);
      int64_t ih1 = end_index(oh, output_height, input_height);
      int64_t kh = ih1 - ih0;

      int64_t iw0 = start_index(ow, output_width, input_width);
      int64_t iw1 = end_index(ow, output_width, input_width);
      int64_t kw = iw1 - iw0;

      scalar_t* out = output_data + i * channels;
      int64_t size = channels;

      // Note: For ordinary usage scenario, each out lane should
      // fit in L1 cache; otherwise consider block dim C.
      // Pass I: zero the out lane
      int64_t d1 = 0;
      for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
        Vec out_vec = Vec(scalar_t(0));
        out_vec.store(out + d1);
      }
      for (; d1 < size; d1++) {
        out[d1] = scalar_t(0);
      }
      // Pass II: compute local sum
      for (const auto ih : c10::irange(ih0, ih1)) {
        for (const auto iw : c10::irange(iw0, iw1)) {
          const scalar_t* in = input_data + n * input_height * input_width * channels +
              ih * input_width * channels + iw * channels;

          int64_t d2 = 0;
          for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
            Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2);
            out_vec.store(out + d2);
          }
          for (; d2 < size; d2++) {
            out[d2] += in[d2];
          }
        }
      }
      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
        Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(kh * kw));
        out_vec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = out[d3] / kh / kw;
      }

      // move on to next output index
      data_index_step(n, nbatch, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}
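
// Variant for reduced floating point types (BFloat16/Half), selected when
// scalar_t differs from at::opmath_type<scalar_t>: partial sums are accumulated
// in a temporary float buffer (allocated once per parallel task) and converted
// back to scalar_t only when the final averages are written, which avoids the
// rounding error of accumulating directly in half precision.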
template <typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool2d_channels_last(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto memory_format = at::MemoryFormat::ChannelsLast;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  using bVec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
  // parallel on dim N, H, W
  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

    // temp buffer for sum, use float as accumulation type
    // can't reuse output buffer to store sum since it is BFloat16/Half
    auto sum_arr = std::make_unique<float []>(channels);
    float* sum = sum_arr.get();

    for (const auto i : c10::irange(begin, end)) {
      int64_t ih0 = start_index(oh, output_height, input_height);
      int64_t ih1 = end_index(oh, output_height, input_height);
      int64_t kh = ih1 - ih0;

      int64_t iw0 = start_index(ow, output_width, input_width);
      int64_t iw1 = end_index(ow, output_width, input_width);
      int64_t kw = iw1 - iw0;

      scalar_t* out = output_data + i * channels;
      int64_t size = channels;

      // Pass I: zero the sum buffer
      int64_t d1 = 0;
      for (; d1 < size - (size % fVec::size()); d1 += fVec::size()) {
        fVec sum_fvec = fVec(float(0));
        sum_fvec.store(sum + d1);
      }
      for (; d1 < size; d1++) {
        sum[d1] = float(0);
      }
      // Pass II: compute local sum
      for (const auto ih : c10::irange(ih0, ih1)) {
        for (const auto iw : c10::irange(iw0, iw1)) {
          const scalar_t* in = input_data + n * input_height * input_width * channels +
              ih * input_width * channels + iw * channels;

          int64_t d2 = 0;
          for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) {
            bVec data_bvec = bVec::loadu(in + d2);
            auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);

            fVec sum_fvec0 = fVec::loadu(sum + d2) + data_fvec0;
            fVec sum_fvec1 = fVec::loadu(sum + d2 + fVec::size()) + data_fvec1;
            sum_fvec0.store(sum + d2);
            sum_fvec1.store(sum + d2 + fVec::size());
          }
          for (; d2 < size; d2++) {
            sum[d2] += float(in[d2]);
          }
        }
      }
      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < size - (size % bVec::size()); d3 += bVec::size()) {
        fVec out_fvec0 = fVec::loadu(sum + d3) / fVec(float(kh * kw));
        fVec out_fvec1 = fVec::loadu(sum + d3 + fVec::size()) / fVec(float(kh * kw));

        bVec out_bvec = convert_from_float<scalar_t>(out_fvec0, out_fvec1);
        out_bvec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = scalar_t(sum[d3] / kh / kw);
      }

      // move on to next output index
      data_index_step(n, nbatch, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}
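
// Backward: every input position inside the pooling window of output (oh, ow)
// receives grad_output[oh][ow] / (kh * kw). Windows of neighboring outputs may
// overlap, so contributions are accumulated into grad_input with +=.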
template <typename scalar_t>
void cpu_adaptive_avg_pool2d_backward(
    Tensor& grad_input_,
    const Tensor& grad_output_) {
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  int64_t ndim = grad_output.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + c * input_height * input_width;
      const scalar_t* grad_output_ptr = grad_output_data + c * output_height * output_width;

      for (const auto oh : c10::irange(output_height)) {
        int64_t ih0 = start_index(oh, output_height, input_height);
        int64_t ih1 = end_index(oh, output_height, input_height);
        int64_t kh = ih1 - ih0;

        for (const auto ow : c10::irange(output_width)) {
          int64_t iw0 = start_index(ow, output_width, input_width);
          int64_t iw1 = end_index(ow, output_width, input_width);
          int64_t kw = iw1 - iw0;

          scalar_t grad_delta = grad_output_ptr[oh * output_width + ow] / kh / kw;
          for (const auto ih : c10::irange(ih0, ih1)) {
            for (const auto iw : c10::irange(iw0, iw1)) {
              grad_input_ptr[ih * input_width + iw] += grad_delta;
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

template <typename scalar_t>
void cpu_adaptive_avg_pool2d_backward_channels_last(
    Tensor& grad_input_,
    const Tensor& grad_output_) {
  auto memory_format = at::MemoryFormat::ChannelsLast;
  auto grad_input = grad_input_.contiguous(memory_format);
  auto grad_output = grad_output_.contiguous(memory_format);

  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();

  int64_t nbatch = grad_input.size(0);
  int64_t channels = grad_input.size(1);
  int64_t input_height = grad_input.size(2);
  int64_t input_width = grad_input.size(3);
  int64_t output_height = grad_output.size(2);
  int64_t output_width = grad_output.size(3);

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N
  at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
    for (const auto n : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + n * input_height * input_width * channels;
      const scalar_t* grad_output_ptr = grad_output_data + n * output_height * output_width * channels;

      for (const auto oh : c10::irange(output_height)) {
        int64_t ih0 = start_index(oh, output_height, input_height);
        int64_t ih1 = end_index(oh, output_height, input_height);
        int64_t kh = ih1 - ih0;

        for (const auto ow : c10::irange(output_width)) {
          int64_t iw0 = start_index(ow, output_width, input_width);
          int64_t iw1 = end_index(ow, output_width, input_width);
          int64_t kw = iw1 - iw0;

          const scalar_t* gout = grad_output_ptr + oh * output_width * channels + ow * channels;
          int64_t size = channels;
          for (const auto ih : c10::irange(ih0, ih1)) {
            for (const auto iw : c10::irange(iw0, iw1)) {
              scalar_t* gin = grad_input_ptr + ih * input_width * channels + iw * channels;

              int64_t d = 0;
              for (; d < size - (size % Vec::size()); d += Vec::size()) {
                Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(kh * kw));
                gin_vec.store(gin + d);
              }
              for (; d < size; d++) {
                gin[d] += gout[d] / kh / kw;
              }
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous(memory_format)) {
    grad_input_.copy_(grad_input);
  }
}
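
// Dispatch on the suggested memory format: NCHW (Contiguous) tensors use the
// plain kernel parallelized over N * C, NHWC (ChannelsLast) tensors use the
// channels-last kernel parallelized over N * H * W with vectorization along C.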
void adaptive_avg_pool2d_kernel_impl(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size) {
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool2d", [&] {
        using param_t = at::opmath_type<scalar_t>;
        cpu_adaptive_avg_pool2d<scalar_t, /*accscalar_t*/param_t>(output, input, output_size);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool2d_channels_last", [&]{
        cpu_adaptive_avg_pool2d_channels_last<scalar_t>(output, input, output_size);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void adapative_avg_pool2d_backward_kernel_impl(
    Tensor& grad_input,
    const Tensor& grad_output) {
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool2d_backward", [&] {
        cpu_adaptive_avg_pool2d_backward<scalar_t>(grad_input, grad_output);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool2d_backward_channels_last", [&]{
        cpu_adaptive_avg_pool2d_backward_channels_last<scalar_t>(grad_input, grad_output);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}
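
// The 3D kernels below mirror the 2D ones with an extra depth dimension:
// output_size is (D, H, W) and the channels-last path uses
// at::MemoryFormat::ChannelsLast3d.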
template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_avg_pool3d(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto input = input_.contiguous();
  auto output = output_.contiguous();

  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t ndim = input.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
  int64_t input_depth = input.size(-3);
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width;
      scalar_t* output_ptr = output_data + c * output_depth * output_height * output_width;

      for (const auto od : c10::irange(output_depth)) {
        int64_t id0 = start_index(od, output_depth, input_depth);
        int64_t id1 = end_index(od, output_depth, input_depth);
        int64_t kd = id1 - id0;

        for (const auto oh : c10::irange(output_height)) {
          int64_t ih0 = start_index(oh, output_height, input_height);
          int64_t ih1 = end_index(oh, output_height, input_height);
          int64_t kh = ih1 - ih0;

          for (const auto ow : c10::irange(output_width)) {
            int64_t iw0 = start_index(ow, output_width, input_width);
            int64_t iw1 = end_index(ow, output_width, input_width);
            int64_t kw = iw1 - iw0;

            // compute local average
            accscalar_t sum = 0;
            for (const auto id : c10::irange(id0, id1)) {
              for (const auto ih : c10::irange(ih0, ih1)) {
                for (const auto iw : c10::irange(iw0, iw1)) {
                  sum += accscalar_t(input_ptr[id * input_height * input_width + ih * input_width + iw]);
                }
              }
            }
            output_ptr[od * output_height * output_width + oh * output_width + ow] = scalar_t(sum / kd / kh / kw);
          }
        }
      }
    }
  });

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
}


template <typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool3d_channels_last(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto memory_format = at::MemoryFormat::ChannelsLast3d;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_depth = input.size(2);
  int64_t input_height = input.size(3);
  int64_t input_width = input.size(4);
  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N, D, H, W
  at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t od = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);

    for (const auto i : c10::irange(begin, end)) {
      int64_t id0 = start_index(od, output_depth, input_depth);
      int64_t id1 = end_index(od, output_depth, input_depth);
      int64_t kd = id1 - id0;

      int64_t ih0 = start_index(oh, output_height, input_height);
      int64_t ih1 = end_index(oh, output_height, input_height);
      int64_t kh = ih1 - ih0;

      int64_t iw0 = start_index(ow, output_width, input_width);
      int64_t iw1 = end_index(ow, output_width, input_width);
      int64_t kw = iw1 - iw0;

      scalar_t* out = output_data + i * channels;
      int64_t size = channels;

      // Note: For ordinary usage scenario, each out lane should
      // fit in L1 cache; otherwise consider block dim C.
      // Pass I: zero the out lane
      int64_t d1 = 0;
      for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
        Vec out_vec = Vec(scalar_t(0));
        out_vec.store(out + d1);
      }
      for (; d1 < size; d1++) {
        out[d1] = scalar_t(0);
      }
      // Pass II: compute local sum
      for (const auto id : c10::irange(id0, id1)) {
        for (const auto ih : c10::irange(ih0, ih1)) {
          for (const auto iw : c10::irange(iw0, iw1)) {
            scalar_t* in = input_data + n * input_depth * input_height * input_width * channels +
                id * input_height * input_width * channels + ih * input_width * channels + iw * channels;

            int64_t d2 = 0;
            for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
              Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2);
              out_vec.store(out + d2);
            }
            for (; d2 < size; d2++) {
              out[d2] += in[d2];
            }
          }
        }
      }
      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
        Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(kd * kh * kw));
        out_vec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = out[d3] / kd / kh / kw;
      }

      // move on to next output index
      data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}
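
// Reduced floating point (BFloat16/Half) variant of the 3D channels-last kernel;
// as in the 2D case, it accumulates into a temporary float buffer per task.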
template <typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool3d_channels_last(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto memory_format = at::MemoryFormat::ChannelsLast3d;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_depth = input.size(2);
  int64_t input_height = input.size(3);
  int64_t input_width = input.size(4);
  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  using bVec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
  // parallel on dim N, D, H, W
  at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    int64_t od = 0;
    data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);

    // temp buffer for sum, use float as accumulation type
    // can't reuse output buffer to store sum since it is BFloat16/Half
    auto sum_arr = std::make_unique<float []>(channels);
    float* sum = sum_arr.get();

    for (const auto i : c10::irange(begin, end)) {
      int64_t id0 = start_index(od, output_depth, input_depth);
      int64_t id1 = end_index(od, output_depth, input_depth);
      int64_t kd = id1 - id0;

      int64_t ih0 = start_index(oh, output_height, input_height);
      int64_t ih1 = end_index(oh, output_height, input_height);
      int64_t kh = ih1 - ih0;

      int64_t iw0 = start_index(ow, output_width, input_width);
      int64_t iw1 = end_index(ow, output_width, input_width);
      int64_t kw = iw1 - iw0;

      scalar_t* out = output_data + i * channels;
      int64_t size = channels;

      // Pass I: zero the sum buffer
      int64_t d1 = 0;
      for (; d1 < size - (size % fVec::size()); d1 += fVec::size()) {
        fVec sum_fvec = fVec(float(0));
        sum_fvec.store(sum + d1);
      }
      for (; d1 < size; d1++) {
        sum[d1] = float(0);
      }
      // Pass II: compute local sum
      for (const auto id : c10::irange(id0, id1)) {
        for (const auto ih : c10::irange(ih0, ih1)) {
          for (const auto iw : c10::irange(iw0, iw1)) {
            scalar_t* in = input_data + n * input_depth * input_height * input_width * channels +
                id * input_height * input_width * channels +
                ih * input_width * channels + iw * channels;

            int64_t d2 = 0;
            for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) {
              bVec data_bvec = bVec::loadu(in + d2);
              auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);

              fVec sum_fvec0 = fVec::loadu(sum + d2) + data_fvec0;
              fVec sum_fvec1 = fVec::loadu(sum + d2 + fVec::size()) + data_fvec1;
              sum_fvec0.store(sum + d2);
              sum_fvec1.store(sum + d2 + fVec::size());
            }
            for (; d2 < size; d2++) {
              sum[d2] += float(in[d2]);
            }
          }
        }
      }
      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < size - (size % bVec::size()); d3 += bVec::size()) {
        fVec out_fvec0 = fVec::loadu(sum + d3) / fVec(float(kd * kh * kw));
        fVec out_fvec1 = fVec::loadu(sum + d3 + fVec::size()) / fVec(float(kd * kh * kw));

        bVec out_bvec = convert_from_float<scalar_t>(out_fvec0, out_fvec1);
        out_bvec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = scalar_t(sum[d3] / kd / kh / kw);
      }

      // move on to next output index
      data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}
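
// Backward for 3D: each input position inside the (kd, kh, kw) window accumulates
// grad_output / (kd * kh * kw), first for the contiguous layout, then for
// channels-last with vectorization along C.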
template <typename scalar_t>
void cpu_adaptive_avg_pool3d_backward(
    Tensor& grad_input_,
    const Tensor& grad_output_) {
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  int64_t ndim = grad_output.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  int64_t input_depth = grad_input.size(-3);
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_depth = grad_output.size(-3);
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;
      scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;

      for (const auto od : c10::irange(output_depth)) {
        int64_t id0 = start_index(od, output_depth, input_depth);
        int64_t id1 = end_index(od, output_depth, input_depth);
        int64_t kd = id1 - id0;
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih0 = start_index(oh, output_height, input_height);
          int64_t ih1 = end_index(oh, output_height, input_height);
          int64_t kh = ih1 - ih0;

          for (const auto ow : c10::irange(output_width)) {
            int64_t iw0 = start_index(ow, output_width, input_width);
            int64_t iw1 = end_index(ow, output_width, input_width);
            int64_t kw = iw1 - iw0;

            scalar_t grad_delta = grad_output_ptr[od * output_width * output_height + oh * output_width + ow] / kd / kh / kw;
            for (const auto id : c10::irange(id0, id1)) {
              for (const auto ih : c10::irange(ih0, ih1)) {
                for (const auto iw : c10::irange(iw0, iw1)) {
                  grad_input_ptr[id * input_height * input_width + ih * input_width + iw] += grad_delta;
                }
              }
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

template <typename scalar_t>
void cpu_adaptive_avg_pool3d_backward_channels_last(
    Tensor& grad_input_,
    const Tensor& grad_output_) {
  auto memory_format = at::MemoryFormat::ChannelsLast3d;
  auto grad_input = grad_input_.contiguous(memory_format);
  auto grad_output = grad_output_.contiguous(memory_format);

  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto grad_output_data = grad_output.data_ptr<scalar_t>();

  int64_t nbatch = grad_input.size(0);
  int64_t channels = grad_input.size(1);
  int64_t input_depth = grad_input.size(2);
  int64_t input_height = grad_input.size(3);
  int64_t input_width = grad_input.size(4);
  int64_t output_depth = grad_output.size(2);
  int64_t output_height = grad_output.size(3);
  int64_t output_width = grad_output.size(4);

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N
  at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
    for (const auto n : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + n * input_depth * input_height * input_width * channels;
      scalar_t* grad_output_ptr = grad_output_data + n * output_depth * output_height * output_width * channels;

      for (const auto od : c10::irange(output_depth)) {
        int64_t id0 = start_index(od, output_depth, input_depth);
        int64_t id1 = end_index(od, output_depth, input_depth);
        int64_t kd = id1 - id0;
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih0 = start_index(oh, output_height, input_height);
          int64_t ih1 = end_index(oh, output_height, input_height);
          int64_t kh = ih1 - ih0;

          for (const auto ow : c10::irange(output_width)) {
            int64_t iw0 = start_index(ow, output_width, input_width);
            int64_t iw1 = end_index(ow, output_width, input_width);
            int64_t kw = iw1 - iw0;
            scalar_t* gout = grad_output_ptr + od * output_height * output_width * channels + oh * output_width * channels + ow * channels;
            int64_t size = channels;
            for (const auto id : c10::irange(id0, id1)) {
              for (const auto ih : c10::irange(ih0, ih1)) {
                for (const auto iw : c10::irange(iw0, iw1)) {
                  scalar_t* gin = grad_input_ptr + id * input_width * input_height * channels + ih * input_width * channels + iw * channels;

                  int64_t d = 0;
                  for (; d < size - (size % Vec::size()); d += Vec::size()) {
                    Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(kd * kh * kw));
                    gin_vec.store(gin + d);
                  }
                  for (; d < size; d++) {
                    gin[d] += gout[d] / kd / kh / kw;
                  }
                }
              }
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous(memory_format)) {
    grad_input_.copy_(grad_input);
  }
}


void adaptive_avg_pool3d_kernel_impl(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size) {
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool3d", [&] {
        using param_t = at::opmath_type<scalar_t>;
        cpu_adaptive_avg_pool3d<scalar_t, /*accscalar_t*/param_t>(output, input, output_size);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool3d_channels_last", [&]{
        cpu_adaptive_avg_pool3d_channels_last<scalar_t>(output, input, output_size);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void adapative_avg_pool3d_backward_kernel_impl(
    Tensor& grad_input,
    const Tensor& grad_output) {
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool3d_backward", [&] {
        cpu_adaptive_avg_pool3d_backward<scalar_t>(grad_input, grad_output);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool3d_backward_channels_last", [&]{
        cpu_adaptive_avg_pool3d_backward_channels_last<scalar_t>(grad_input, grad_output);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

} // anonymous namespace

REGISTER_DISPATCH(adaptive_avg_pool2d_kernel, &adaptive_avg_pool2d_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool2d_backward_kernel, &adapative_avg_pool2d_backward_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool3d_kernel, &adaptive_avg_pool3d_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool3d_backward_kernel, &adapative_avg_pool3d_backward_kernel_impl);

} // at::native
863