// aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/native/AdaptivePooling.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>

namespace at::native {

namespace {

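// Note on the bin helpers used throughout this file: start_index()/end_index()
// (see the ATen/native/AdaptivePooling.h include above) map an output index o
// to the input range [floor(o * isize / osize), ceil((o + 1) * isize / osize)).
// Bins may overlap and may differ in size. A small worked example, assuming
// isize = 5 and osize = 3:
//   o = 0 -> [0, 2)  (k = 2)
//   o = 1 -> [1, 4)  (k = 3)
//   o = 2 -> [3, 5)  (k = 2)
// The kernels below only rely on these half-open ranges and the bin sizes k.
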
template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_avg_pool2d(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto input = input_.contiguous();
  auto output = output_.contiguous();

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t ndim = input.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      const scalar_t* input_ptr = input_data + c * input_height * input_width;
      scalar_t* output_ptr = output_data + c * output_height * output_width;

      for (const auto oh : c10::irange(output_height)) {
        int64_t ih0 = start_index(oh, output_height, input_height);
        int64_t ih1 = end_index(oh, output_height, input_height);
        int64_t kh = ih1 - ih0;

        for (const auto ow : c10::irange(output_width)) {
          int64_t iw0 = start_index(ow, output_width, input_width);
          int64_t iw1 = end_index(ow, output_width, input_width);
          int64_t kw = iw1 - iw0;

          // compute local average
          accscalar_t sum = 0;
          for (const auto ih : c10::irange(ih0, ih1)) {
            for (const auto iw : c10::irange(iw0, iw1)) {
              sum += accscalar_t(input_ptr[ih * input_width + iw]);
            }
          }
          output_ptr[oh * output_width + ow] = scalar_t(sum / kh / kw);
        }
      }
    }
  });

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
}
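
// For reference, the kernel above computes, for every (flattened) channel c:
//   output[c][oh][ow] = 1 / (kh * kw) * sum of input[c][ih][iw]
//                       over ih in [ih0, ih1), iw in [iw0, iw1)
// where [ih0, ih1) and [iw0, iw1) are the adaptive bins described at the top
// of this namespace and kh = ih1 - ih0, kw = iw1 - iw0.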

template <typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool2d_channels_last(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto memory_format = at::MemoryFormat::ChannelsLast;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N, H, W
  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

    for (const auto i : c10::irange(begin, end)) {
      int64_t ih0 = start_index(oh, output_height, input_height);
      int64_t ih1 = end_index(oh, output_height, input_height);
      int64_t kh = ih1 - ih0;

      int64_t iw0 = start_index(ow, output_width, input_width);
      int64_t iw1 = end_index(ow, output_width, input_width);
      int64_t kw = iw1 - iw0;

      scalar_t* out = output_data + i * channels;
      int64_t size = channels;

      // Note: for the ordinary usage scenario, each out lane should
      //   fit in L1 cache; otherwise consider blocking on dim C.
      // Pass I: zero the out lane
      int64_t d1 = 0;
      for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
        Vec out_vec = Vec(scalar_t(0));
        out_vec.store(out + d1);
      }
      for (; d1 < size; d1++) {
        out[d1] = scalar_t(0);
      }
      // Pass II: compute local sum
      for (const auto ih : c10::irange(ih0, ih1)) {
        for (const auto iw : c10::irange(iw0, iw1)) {
          const scalar_t* in = input_data + n * input_height * input_width * channels +
              ih * input_width * channels + iw * channels;

          int64_t d2 = 0;
          for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
            Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2);
            out_vec.store(out + d2);
          }
          for (; d2 < size; d2++) {
            out[d2] += in[d2];
          }
        }
      }
      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
        Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(kh * kw));
        out_vec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = out[d3] / kh / kw;
      }

      // move on to next output index
      data_index_step(n, nbatch, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}
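
// Note on the loop structure above: each "Pass" is split into a vectorized main
// loop that advances in steps of Vec::size() lanes and a scalar tail loop for
// the remaining size % Vec::size() channels. A rough illustration (exact lane
// counts depend on the build's vector ISA): with scalar_t = float on an AVX2
// build, Vec::size() is 8, so C = 19 channels are handled as two 8-wide vector
// iterations plus a 3-element scalar tail.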

template <typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool2d_channels_last(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto memory_format = at::MemoryFormat::ChannelsLast;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  using bVec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
  // parallel on dim N, H, W
  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

    // temp buffer for sum, use float as accumulation type
    // can't reuse output buffer to store sum since it is BFloat16/Half
    auto sum_arr = std::make_unique<float []>(channels);
    float* sum = sum_arr.get();

    for (const auto i : c10::irange(begin, end)) {
      int64_t ih0 = start_index(oh, output_height, input_height);
      int64_t ih1 = end_index(oh, output_height, input_height);
      int64_t kh = ih1 - ih0;

      int64_t iw0 = start_index(ow, output_width, input_width);
      int64_t iw1 = end_index(ow, output_width, input_width);
      int64_t kw = iw1 - iw0;

      scalar_t* out = output_data + i * channels;
      int64_t size = channels;

      // Pass I: zero the out lane
      int64_t d1 = 0;
      for (; d1 < size - (size % fVec::size()); d1 += fVec::size()) {
        fVec sum_fvec = fVec(float(0));
        sum_fvec.store(sum + d1);
      }
      for (; d1 < size; d1++) {
        sum[d1] = float(0);
      }
      // Pass II: compute local sum
      for (const auto ih : c10::irange(ih0, ih1)) {
        for (const auto iw : c10::irange(iw0, iw1)) {
          const scalar_t* in = input_data + n * input_height * input_width * channels +
              ih * input_width * channels + iw * channels;

          int64_t d2 = 0;
          for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) {
            bVec data_bvec = bVec::loadu(in + d2);
            auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);

            fVec sum_fvec0 = fVec::loadu(sum + d2) + data_fvec0;
            fVec sum_fvec1 = fVec::loadu(sum + d2 + fVec::size()) + data_fvec1;
            sum_fvec0.store(sum + d2);
            sum_fvec1.store(sum + d2 + fVec::size());
          }
          for (; d2 < size; d2++) {
            sum[d2] += float(in[d2]);
          }
        }
      }
      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < size - (size % bVec::size()); d3 += bVec::size()) {
        fVec out_fvec0 = fVec::loadu(sum + d3) / fVec(float(kh * kw));
        fVec out_fvec1 = fVec::loadu(sum + d3 + fVec::size()) / fVec(float(kh * kw));

        bVec out_bvec = convert_from_float<scalar_t>(out_fvec0, out_fvec1);
        out_bvec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = scalar_t(sum[d3] / kh / kw);
      }

      // move on to next output index
      data_index_step(n, nbatch, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}
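
// Note on the reduced-precision path above: one bVec of BFloat16/Half holds
// twice as many lanes as one fVec of float, so each bVec load is widened into
// the pair (data_fvec0, data_fvec1) and accumulated into two adjacent fVec
// slots of the float `sum` buffer. As a rough illustration, on an AVX2 build
// bVec::size() is 16 and fVec::size() is 8, so channels are consumed 16 at a
// time in Pass II and converted back 16 at a time in Pass III.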

template <typename scalar_t>
void cpu_adaptive_avg_pool2d_backward(
    Tensor& grad_input_,
    const Tensor& grad_output_) {
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  int64_t ndim = grad_output.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + c * input_height * input_width;
      const scalar_t* grad_output_ptr = grad_output_data + c * output_height * output_width;

      for (const auto oh : c10::irange(output_height)) {
        int64_t ih0 = start_index(oh, output_height, input_height);
        int64_t ih1 = end_index(oh, output_height, input_height);
        int64_t kh = ih1 - ih0;

        for (const auto ow : c10::irange(output_width)) {
          int64_t iw0 = start_index(ow, output_width, input_width);
          int64_t iw1 = end_index(ow, output_width, input_width);
          int64_t kw = iw1 - iw0;

          scalar_t grad_delta = grad_output_ptr[oh * output_width + ow] / kh / kw;
          for (const auto ih : c10::irange(ih0, ih1)) {
            for (const auto iw : c10::irange(iw0, iw1)) {
              grad_input_ptr[ih * input_width + iw] += grad_delta;
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}
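
// For reference, the backward pass above scatters, for each output bin,
//   grad_input[c][ih][iw] += grad_output[c][oh][ow] / (kh * kw)
// over all (ih, iw) in the bin; input positions covered by several overlapping
// bins accumulate one such contribution per bin. Since the kernel only does
// +=, it relies on grad_input_ being zero-initialized by the caller.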

template <typename scalar_t>
void cpu_adaptive_avg_pool2d_backward_channels_last(
    Tensor& grad_input_,
    const Tensor& grad_output_) {
  auto memory_format = at::MemoryFormat::ChannelsLast;
  auto grad_input = grad_input_.contiguous(memory_format);
  auto grad_output = grad_output_.contiguous(memory_format);

  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();

  int64_t nbatch = grad_input.size(0);
  int64_t channels = grad_input.size(1);
  int64_t input_height = grad_input.size(2);
  int64_t input_width = grad_input.size(3);
  int64_t output_height = grad_output.size(2);
  int64_t output_width = grad_output.size(3);

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N
  at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
    for (const auto n : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + n * input_height * input_width * channels;
      const scalar_t* grad_output_ptr = grad_output_data + n * output_height * output_width * channels;

      for (const auto oh : c10::irange(output_height)) {
        int64_t ih0 = start_index(oh, output_height, input_height);
        int64_t ih1 = end_index(oh, output_height, input_height);
        int64_t kh = ih1 - ih0;

        for (const auto ow : c10::irange(output_width)) {
          int64_t iw0 = start_index(ow, output_width, input_width);
          int64_t iw1 = end_index(ow, output_width, input_width);
          int64_t kw = iw1 - iw0;

          const scalar_t* gout = grad_output_ptr + oh * output_width * channels + ow * channels;
          int64_t size = channels;
          for (const auto ih : c10::irange(ih0, ih1)) {
            for (const auto iw : c10::irange(iw0, iw1)) {
              scalar_t* gin = grad_input_ptr + ih * input_width * channels + iw * channels;

              int64_t d = 0;
              for (; d < size - (size % Vec::size()); d += Vec::size()) {
                Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(kh * kw));
                gin_vec.store(gin + d);
              }
              for (; d < size; d++) {
                gin[d] += gout[d] / kh / kw;
              }
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous(memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

void adaptive_avg_pool2d_kernel_impl(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size) {
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool2d", [&] {
        using param_t = at::opmath_type<scalar_t>;
        cpu_adaptive_avg_pool2d<scalar_t, /*accscalar_t*/param_t>(output, input, output_size);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool2d_channels_last", [&]{
        cpu_adaptive_avg_pool2d_channels_last<scalar_t>(output, input, output_size);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}
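
// Rough usage sketch (shapes are arbitrary examples; assumes a float CPU
// tensor): the dispatch above is reached through the ATen operator, and the
// suggested memory format of `input` picks the kernel variant.
//
//   auto x  = at::randn({8, 3, 32, 32});                     // NCHW, contiguous
//   auto y  = at::adaptive_avg_pool2d(x, {7, 7});            // -> cpu_adaptive_avg_pool2d
//
//   auto xc = x.contiguous(at::MemoryFormat::ChannelsLast);  // NHWC strides
//   auto yc = at::adaptive_avg_pool2d(xc, {7, 7});           // -> cpu_adaptive_avg_pool2d_channels_last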

void adapative_avg_pool2d_backward_kernel_impl(
    Tensor& grad_input,
    const Tensor& grad_output) {
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool2d_backward", [&] {
        cpu_adaptive_avg_pool2d_backward<scalar_t>(grad_input, grad_output);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool2d_backward_channels_last", [&]{
        cpu_adaptive_avg_pool2d_backward_channels_last<scalar_t>(grad_input, grad_output);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}


template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_avg_pool3d(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto input = input_.contiguous();
  auto output = output_.contiguous();

  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t ndim = input.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
  int64_t input_depth = input.size(-3);
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width;
      scalar_t* output_ptr = output_data + c * output_depth * output_height * output_width;

      for (const auto od : c10::irange(output_depth)) {
        int64_t id0 = start_index(od, output_depth, input_depth);
        int64_t id1 = end_index(od, output_depth, input_depth);
        int64_t kd = id1 - id0;

        for (const auto oh : c10::irange(output_height)) {
          int64_t ih0 = start_index(oh, output_height, input_height);
          int64_t ih1 = end_index(oh, output_height, input_height);
          int64_t kh = ih1 - ih0;

          for (const auto ow : c10::irange(output_width)) {
            int64_t iw0 = start_index(ow, output_width, input_width);
            int64_t iw1 = end_index(ow, output_width, input_width);
            int64_t kw = iw1 - iw0;

            // compute local average
            accscalar_t sum = 0;
            for (const auto id : c10::irange(id0, id1)) {
              for (const auto ih : c10::irange(ih0, ih1)) {
                for (const auto iw : c10::irange(iw0, iw1)) {
                  sum += accscalar_t(input_ptr[id * input_height * input_width + ih * input_width + iw]);
                }
              }
            }
            output_ptr[od * output_height * output_width + oh * output_width + ow] = scalar_t(sum / kd / kh / kw);
          }
        }
      }
    }
  });

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
}


template <typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool3d_channels_last(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto memory_format = at::MemoryFormat::ChannelsLast3d;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_depth = input.size(2);
  int64_t input_height = input.size(3);
  int64_t input_width = input.size(4);
  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N, D, H, W
  at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t od = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);

    for (const auto i : c10::irange(begin, end)) {
      int64_t id0 = start_index(od, output_depth, input_depth);
      int64_t id1 = end_index(od, output_depth, input_depth);
      int64_t kd = id1 - id0;

      int64_t ih0 = start_index(oh, output_height, input_height);
      int64_t ih1 = end_index(oh, output_height, input_height);
      int64_t kh = ih1 - ih0;

      int64_t iw0 = start_index(ow, output_width, input_width);
      int64_t iw1 = end_index(ow, output_width, input_width);
      int64_t kw = iw1 - iw0;

      scalar_t* out = output_data + i * channels;
      int64_t size = channels;

      // Note: for the ordinary usage scenario, each out lane should
      //   fit in L1 cache; otherwise consider blocking on dim C.
      // Pass I: zero the out lane
      int64_t d1 = 0;
      for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
        Vec out_vec = Vec(scalar_t(0));
        out_vec.store(out + d1);
      }
      for (; d1 < size; d1++) {
        out[d1] = scalar_t(0);
      }
      // Pass II: compute local sum
      for (const auto id : c10::irange(id0, id1)) {
        for (const auto ih : c10::irange(ih0, ih1)) {
          for (const auto iw : c10::irange(iw0, iw1)) {
            scalar_t* in = input_data + n * input_depth * input_height * input_width * channels +
                id * input_height * input_width * channels + ih * input_width * channels + iw * channels;

            int64_t d2 = 0;
            for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
              Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2);
              out_vec.store(out + d2);
            }
            for (; d2 < size; d2++) {
              out[d2] += in[d2];
            }
          }
        }
      }
      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
        Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(kd * kh * kw));
        out_vec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = out[d3] / kd / kh / kw;
      }

      // move on to next output index
      data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}

template <typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool3d_channels_last(
    Tensor& output_,
    const Tensor& input_,
    IntArrayRef output_size) {
  auto memory_format = at::MemoryFormat::ChannelsLast3d;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_depth = input.size(2);
  int64_t input_height = input.size(3);
  int64_t input_width = input.size(4);
  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  using bVec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
  // parallel on dim N, D, H, W
  at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t od = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);

    // temp buffer for sum, use float as accumulation type
    // can't reuse output buffer to store sum since it is BFloat16/Half
    auto sum_arr = std::make_unique<float []>(channels);
    float* sum = sum_arr.get();

    for (const auto i : c10::irange(begin, end)) {
      int64_t id0 = start_index(od, output_depth, input_depth);
      int64_t id1 = end_index(od, output_depth, input_depth);
      int64_t kd = id1 - id0;

      int64_t ih0 = start_index(oh, output_height, input_height);
      int64_t ih1 = end_index(oh, output_height, input_height);
      int64_t kh = ih1 - ih0;

      int64_t iw0 = start_index(ow, output_width, input_width);
      int64_t iw1 = end_index(ow, output_width, input_width);
      int64_t kw = iw1 - iw0;

      scalar_t* out = output_data + i * channels;
      int64_t size = channels;

      // Pass I: zero the out lane
      int64_t d1 = 0;
      for (; d1 < size - (size % fVec::size()); d1 += fVec::size()) {
        fVec sum_fvec = fVec(float(0));
        sum_fvec.store(sum + d1);
      }
      for (; d1 < size; d1++) {
        sum[d1] = float(0);
      }
      // Pass II: compute local sum
      for (const auto id : c10::irange(id0, id1)) {
        for (const auto ih : c10::irange(ih0, ih1)) {
          for (const auto iw : c10::irange(iw0, iw1)) {
            scalar_t* in = input_data + n * input_depth * input_height * input_width * channels +
                id * input_height * input_width * channels +
                ih * input_width * channels + iw * channels;

            int64_t d2 = 0;
            for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) {
              bVec data_bvec = bVec::loadu(in + d2);
              auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);

              fVec sum_fvec0 = fVec::loadu(sum + d2) + data_fvec0;
              fVec sum_fvec1 = fVec::loadu(sum + d2 + fVec::size()) + data_fvec1;
              sum_fvec0.store(sum + d2);
              sum_fvec1.store(sum + d2 + fVec::size());
            }
            for (; d2 < size; d2++) {
              sum[d2] += float(in[d2]);
            }
          }
        }
      }
      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < size - (size % bVec::size()); d3 += bVec::size()) {
        fVec out_fvec0 = fVec::loadu(sum + d3) / fVec(float(kd * kh * kw));
        fVec out_fvec1 = fVec::loadu(sum + d3 + fVec::size()) / fVec(float(kd * kh * kw));

        bVec out_bvec = convert_from_float<scalar_t>(out_fvec0, out_fvec1);
        out_bvec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = scalar_t(sum[d3] / kd / kh / kw);
      }

      // move on to next output index
      data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}

template <typename scalar_t>
void cpu_adaptive_avg_pool3d_backward(
    Tensor& grad_input_,
    const Tensor& grad_output_) {
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  int64_t ndim = grad_output.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  int64_t input_depth = grad_input.size(-3);
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_depth = grad_output.size(-3);
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;
      scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;

      for (const auto od : c10::irange(output_depth)) {
        int64_t id0 = start_index(od, output_depth, input_depth);
        int64_t id1 = end_index(od, output_depth, input_depth);
        int64_t kd = id1 - id0;
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih0 = start_index(oh, output_height, input_height);
          int64_t ih1 = end_index(oh, output_height, input_height);
          int64_t kh = ih1 - ih0;

          for (const auto ow : c10::irange(output_width)) {
            int64_t iw0 = start_index(ow, output_width, input_width);
            int64_t iw1 = end_index(ow, output_width, input_width);
            int64_t kw = iw1 - iw0;

            scalar_t grad_delta = grad_output_ptr[od * output_height * output_width + oh * output_width + ow] / kd / kh / kw;
            for (const auto id : c10::irange(id0, id1)) {
              for (const auto ih : c10::irange(ih0, ih1)) {
                for (const auto iw : c10::irange(iw0, iw1)) {
                  grad_input_ptr[id * input_height * input_width + ih * input_width + iw] += grad_delta;
                }
              }
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

template <typename scalar_t>
void cpu_adaptive_avg_pool3d_backward_channels_last(
    Tensor& grad_input_,
    const Tensor& grad_output_) {
  auto memory_format = at::MemoryFormat::ChannelsLast3d;
  auto grad_input = grad_input_.contiguous(memory_format);
  auto grad_output = grad_output_.contiguous(memory_format);

  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto grad_output_data = grad_output.data_ptr<scalar_t>();

  int64_t nbatch = grad_input.size(0);
  int64_t channels = grad_input.size(1);
  int64_t input_depth = grad_input.size(2);
  int64_t input_height = grad_input.size(3);
  int64_t input_width = grad_input.size(4);
  int64_t output_depth = grad_output.size(2);
  int64_t output_height = grad_output.size(3);
  int64_t output_width = grad_output.size(4);

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N
  at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
    for (const auto n : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + n * input_depth * input_height * input_width * channels;
      scalar_t* grad_output_ptr = grad_output_data + n * output_depth * output_height * output_width * channels;

      for (const auto od : c10::irange(output_depth)) {
        int64_t id0 = start_index(od, output_depth, input_depth);
        int64_t id1 = end_index(od, output_depth, input_depth);
        int64_t kd = id1 - id0;
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih0 = start_index(oh, output_height, input_height);
          int64_t ih1 = end_index(oh, output_height, input_height);
          int64_t kh = ih1 - ih0;

          for (const auto ow : c10::irange(output_width)) {
            int64_t iw0 = start_index(ow, output_width, input_width);
            int64_t iw1 = end_index(ow, output_width, input_width);
            int64_t kw = iw1 - iw0;

            scalar_t* gout = grad_output_ptr + od * output_height * output_width * channels + oh * output_width * channels + ow * channels;
            int64_t size = channels;
            for (const auto id : c10::irange(id0, id1)) {
              for (const auto ih : c10::irange(ih0, ih1)) {
                for (const auto iw : c10::irange(iw0, iw1)) {
                  scalar_t* gin = grad_input_ptr + id * input_height * input_width * channels + ih * input_width * channels + iw * channels;

                  int64_t d = 0;
                  for (; d < size - (size % Vec::size()); d += Vec::size()) {
                    Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(kd * kh * kw));
                    gin_vec.store(gin + d);
                  }
                  for (; d < size; d++) {
                    gin[d] += gout[d] / kd / kh / kw;
                  }
                }
              }
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous(memory_format)) {
    grad_input_.copy_(grad_input);
  }
}


void adaptive_avg_pool3d_kernel_impl(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size) {
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool3d", [&] {
        using param_t = at::opmath_type<scalar_t>;
        cpu_adaptive_avg_pool3d<scalar_t, /*accscalar_t*/param_t>(output, input, output_size);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool3d_channels_last", [&]{
        cpu_adaptive_avg_pool3d_channels_last<scalar_t>(output, input, output_size);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}
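
// Rough usage sketch for the 3d path (shapes are arbitrary examples; assumes a
// float CPU tensor). A 5-d tensor tagged as ChannelsLast3d routes to the
// channels-last kernel above:
//
//   auto v  = at::randn({2, 4, 8, 16, 16});                    // NCDHW, contiguous
//   auto o  = at::adaptive_avg_pool3d(v, {4, 8, 8});           // -> cpu_adaptive_avg_pool3d
//
//   auto vc = v.contiguous(at::MemoryFormat::ChannelsLast3d);  // NDHWC strides
//   auto oc = at::adaptive_avg_pool3d(vc, {4, 8, 8});          // -> cpu_adaptive_avg_pool3d_channels_last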

void adapative_avg_pool3d_backward_kernel_impl(
    Tensor& grad_input,
    const Tensor& grad_output) {
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool3d_backward", [&] {
        cpu_adaptive_avg_pool3d_backward<scalar_t>(grad_input, grad_output);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool3d_backward_channels_last", [&]{
        cpu_adaptive_avg_pool3d_backward_channels_last<scalar_t>(grad_input, grad_output);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

} // anonymous namespace

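// The REGISTER_DISPATCH calls below plug the CPU implementations above into the
// dispatch stubs (adaptive_avg_pool2d_kernel, adaptive_avg_pool3d_kernel, and
// their backward counterparts) that the device-generic adaptive average pooling
// operators invoke; the stub declarations are expected to live in
// ATen/native/AdaptivePooling.h, which is included at the top of this file.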
REGISTER_DISPATCH(adaptive_avg_pool2d_kernel, &adaptive_avg_pool2d_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool2d_backward_kernel, &adapative_avg_pool2d_backward_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool3d_kernel, &adaptive_avg_pool3d_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool3d_backward_kernel, &adapative_avg_pool3d_backward_kernel_impl);

} // at::native