#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/Pool.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/avg_pool2d_native.h>
#endif

#include <c10/util/irange.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

namespace at {
namespace native {

DEFINE_DISPATCH(qavg_pool2d_nhwc_stub);

namespace {

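// Reference (non-channels-last) kernel: averages each pooling window in the
// integer domain and requantizes the result into the output tensor's
// scale/zero_point, parallelizing over the flattened batch*channel dimension.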
template <typename scalar_t>
static void avg_pool2d_out_frame(
    const Tensor& input,
    Tensor& output,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t outputWidth,
    int64_t outputHeight,
    int kW,
    int kH,
    int dW,
    int dH,
    int padW,
    int padH,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  Tensor input_contig = input.contiguous();
  auto input_data = input_contig.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();
  const auto scale_factor = input.q_scale() / output.q_scale();
  const auto input_zero_point = input.q_zero_point();
  const auto output_zero_point = output.q_zero_point();

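  // Each output element is requantized with a single multiplier,
  // scale_factor / divide_factor, after subtracting the input zero_point.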
  at::parallel_for(0, nInputPlane, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      /* For all output pixels... */
      scalar_t* ptr_output = output_data + k * outputWidth * outputHeight;
      const scalar_t* ptr_input = input_data + k * inputWidth * inputHeight;
      auto minimum =
          std::numeric_limits<typename scalar_t::underlying>::lowest();
      auto maximum = std::numeric_limits<typename scalar_t::underlying>::max();

      for (int64_t yy = 0; yy < outputHeight; yy++) {
        for (int64_t xx = 0; xx < outputWidth; xx++) {
          /* Compute the mean of the input image... */
          int64_t hstart = yy * dH - padH;
          int64_t wstart = xx * dW - padW;
          int64_t hend = std::min(hstart + kH, inputHeight + padH);
          int64_t wend = std::min(wstart + kW, inputWidth + padW);
          int64_t pool_size = (hend - hstart) * (wend - wstart);
          hstart = std::max(hstart, int64_t{0});
          wstart = std::max(wstart, int64_t{0});
          hend = std::min(hend, inputHeight);
          wend = std::min(wend, inputWidth);

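          // pool_size counts padded positions; after the clamping above, the
          // window covers only real input pixels. count_include_pad chooses
          // which of the two counts divides the sum.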
          int sum_int = 0;
          ptr_output->val_ = 0;

          int64_t size = (hend - hstart) * (wend - wstart);
          // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
          int64_t divide_factor;
          if (divisor_override.has_value()) {
            divide_factor = divisor_override.value();
          } else {
            divide_factor = count_include_pad ? pool_size : size;
          }

          for (int64_t ky = hstart; ky < hend; ky++) {
            for (int64_t kx = wstart; kx < wend; kx++) {
              sum_int += (ptr_input + ky * inputWidth + kx)->val_;
            }
          }
          // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
          float multiplier = scale_factor / divide_factor;

          // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
          sum_int -= size * input_zero_point;
          float sum = static_cast<float>(sum_int);
          /* Update output by requantizing the result */
          ptr_output->val_ =
              static_cast<typename scalar_t::underlying>(std::min<int32_t>(
                  std::max<int32_t>(
                      std::nearbyint(sum * multiplier + output_zero_point),
                      minimum),
                  maximum));
          ptr_output++;
        }
      }
    }
  });
}

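// The following helpers mirror avg_pool2d's argument conventions: each of
// kernel_size, stride, and padding may be a single int (applied to both
// dimensions) or a pair of ints; stride additionally defaults to the kernel
// size when omitted.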
inline std::pair<int, int> get_kernel(IntArrayRef kernel_size) {
  TORCH_CHECK(
      kernel_size.size() == 1 || kernel_size.size() == 2,
      "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1
      ? kH
      : safe_downcast<int, int64_t>(kernel_size[1]);
  return std::make_pair(kW, kH);
}

inline std::pair<int, int> get_stride(IntArrayRef stride, int kW, int kH) {
  TORCH_CHECK(
      stride.empty() || stride.size() == 1 || stride.size() == 2,
      "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty()
      ? kW
      : stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
  return std::make_pair(dW, dH);
}

inline std::pair<int, int> get_padding(IntArrayRef padding) {
  TORCH_CHECK(
      padding.size() == 1 || padding.size() == 2,
      "avg_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW =
      padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
  return std::make_pair(padW, padH);
}

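// Computes the output sizes via the shared pooling_output_shape helper
// (dilation fixed at 1), preserving the 3-D vs. 4-D layout of the input.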
std::vector<int64_t> get_output_shape(
    const Tensor& input_,
    int kW,
    int kH,
    int dW,
    int dH,
    int padW,
    int padH,
    bool ceil_mode) {
  const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
  const int64_t nInputPlane = input_.size(-3);
  const int64_t inputHeight = input_.size(-2);
  const int64_t inputWidth = input_.size(-1);
  const int64_t outputHeight =
      pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
  const int64_t outputWidth =
      pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
  if (input_.ndimension() == 3) {
    return {nInputPlane, outputHeight, outputWidth};
  }
  return {nbatch, nInputPlane, outputHeight, outputWidth};
}

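// Dispatches between two implementations: a channels-last fast path handled
// by qavg_pool2d_nhwc_stub, and the reference avg_pool2d_out_frame kernel for
// contiguous (NCHW) input, with the batch and channel dims flattened together.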
template <typename scalar_t>
Tensor q_avg_pool2d(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  auto [kW, kH] = get_kernel(kernel_size);
  auto [dW, dH] = get_stride(stride, kW, kH);
  auto [padW, padH] = get_padding(padding);

  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  TORCH_CHECK(
      !divisor_override.has_value() || divisor_override.value() != 0,
      "divisor must not be zero");

  auto output_shape =
      get_output_shape(input, kW, kH, dW, dH, padW, padH, ceil_mode);
  const int64_t outputHeight = output_shape[output_shape.size() - 2];
  const int64_t outputWidth = output_shape[output_shape.size() - 1];
  if (input.is_contiguous(c10::MemoryFormat::ChannelsLast)) {
    auto output = at::_empty_affine_quantized(
        output_shape,
        input.options().memory_format(input.suggest_memory_format()),
        input.q_scale(),
        input.q_zero_point(),
        std::nullopt);
    // fast path for channels last: qavg_pool2d_nhwc_stub
    qavg_pool2d_nhwc_stub(
        input.device().type(),
        input,
        output,
        nbatch,
        nInputPlane,
        inputWidth,
        inputHeight,
        outputWidth,
        outputHeight,
        kW,
        kH,
        dW,
        dH,
        padW,
        padH,
        count_include_pad,
        divisor_override);
    return output;
  } else {
    auto output = at::_empty_affine_quantized(
        output_shape, input.options(), input.q_scale(), input.q_zero_point());
    avg_pool2d_out_frame<scalar_t>(
        input,
        output,
        // Contract batch and channels into one dimension
        nbatch * nInputPlane,
        inputWidth,
        inputHeight,
        outputWidth,
        outputHeight,
        kW,
        kH,
        dW,
        dH,
        padW,
        padH,
        count_include_pad,
        divisor_override);
    return output;
  }
}
} // namespace

#ifdef USE_PYTORCH_QNNPACK
namespace qnnp_avgpool_helper {
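// QNNPACK-backed uint8 average pooling: validates a 4-D quint8 input, then
// follows the QNNPACK operator lifecycle of create -> setup -> run on an
// NHWC (channels-last) buffer before restoring the suggested memory format.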
Tensor qnnpack_avg_pool2d(
    Tensor input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  auto [kW, kH] = get_kernel(kernel_size);
  auto [dW, dH] = get_stride(stride, kW, kH);
  auto [padW, padH] = get_padding(padding);
  TORCH_CHECK(
      input.ndimension() == 4,
      "qnnpack_avg_pool2d(): Expected input to be 4-dimensional: got ",
      input.ndimension());
  TORCH_CHECK(
      input.scalar_type() == c10::kQUInt8,
      "qnnpack_avg_pool2d(): Expected input data type ",
      toString(c10::kQUInt8),
      " but got ",
      toString(input.scalar_type()));

  int64_t batch_size = input.size(0);
  int64_t inC = input.size(1);
  int64_t inH = input.size(2);
  int64_t inW = input.size(3);
  auto output_shape =
      get_output_shape(input, kW, kH, dW, dH, padW, padH, ceil_mode);
  const int64_t oH = output_shape[output_shape.size() - 2];
  const int64_t oW = output_shape[output_shape.size() - 1];
  const auto outC = inC;

  Tensor input_contig = input.contiguous(c10::MemoryFormat::ChannelsLast);

  initQNNPACK();
  const auto scale = input_contig.q_scale();
  const auto zero_point = input_contig.q_zero_point();

  TORCH_CHECK(
      oH > 0 && oW > 0,
      "qnnpack_avg_pool2d(): the resulting output Tensor size should be > 0");
  // NHWC output
  auto output = at::_empty_affine_quantized(
      output_shape,
      at::device(kCPU).dtype(kQUInt8),
      scale,
      zero_point,
      c10::MemoryFormat::ChannelsLast);

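  // The output deliberately reuses the input's scale and zero_point, so the
  // operator effectively only averages the quantized values without rescaling.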
  pytorch_qnnp_operator_t qnnpack_operator{nullptr};
  const pytorch_qnnp_status createStatus =
      pytorch_qnnp_create_average_pooling2d_nhwc_q8(
          padH /* input_padding_height */,
          padW /* input_padding_width */,
          kH /* kernel height */,
          kW /* kernel width */,
          dH /* stride height */,
          dW /* stride width */,
          inC /* input channels */,
          zero_point /* input zero_point */,
          scale /* input scale */,
          zero_point /* output zero_point */,
          scale /* output scale */,
          std::numeric_limits<uint8_t>::min() /* output min */,
          std::numeric_limits<uint8_t>::max() /* output max */,
          0 /* flags */,
          &qnnpack_operator);
  CAFFE_ENFORCE(
      createStatus == pytorch_qnnp_status_success,
      "failed to create QNNPACK Average Pooling operator");
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(qnnpack_operator);

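  // qnnpack_uniq_ptr owns the operator from here on, so it is destroyed on
  // every exit path, including the assertion failures below.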
  const pytorch_qnnp_status setupStatus =
      pytorch_qnnp_setup_average_pooling2d_nhwc_q8(
          qnnpack_operator,
          batch_size,
          inH,
          inW,
          (uint8_t*)input_contig.data_ptr<c10::quint8>() /* input data */,
          inC,
          (uint8_t*)output.data_ptr<c10::quint8>() /* output data */,
          outC,
          nullptr /* thread pool */);
  CAFFE_ENFORCE(
      setupStatus == pytorch_qnnp_status_success,
      "failed to setup QNNPACK Average Pooling operator");
  pthreadpool_t threadpool = caffe2::pthreadpool_();
  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK Average Pool operator");
  return output.contiguous(input.suggest_memory_format());
}
} // namespace qnnp_avgpool_helper
#endif

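// Public entry point: routes to QNNPACK when it is the active engine and the
// request fits its constraints (quint8 input, ceil_mode == false); otherwise
// falls back to q_avg_pool2d via AT_DISPATCH_QINT_TYPES.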
Tensor avg_pool2d_quantized_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  Tensor output;
#ifdef USE_PYTORCH_QNNPACK
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
      input.scalar_type() == kQUInt8 && !ceil_mode) {
    return at::native::qnnp_avgpool_helper::qnnpack_avg_pool2d(
        input,
        kernel_size,
        stride,
        padding,
        ceil_mode,
        count_include_pad,
        divisor_override);
  }
#endif
  AT_DISPATCH_QINT_TYPES(input.scalar_type(), "avg_pool2d_quantized_cpu", [&]() {
    output = q_avg_pool2d<scalar_t>(
        input,
        kernel_size,
        stride,
        padding,
        ceil_mode,
        count_include_pad,
        divisor_override);
  });
  return output;
}

} // namespace native
} // namespace at