// /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qhardswish.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
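//
// Quantized hardswish for CPU. Computes, on the dequantized values,
//   hardswish(x) = x * min(max(x + 3, 0), 6) / 6
// and writes the result with the caller-supplied output scale and zero
// point, via a QNNPACK fast path for quint8 inputs when that engine is
// selected, or the architecture-specific kernel behind qhardswish_stub
// otherwise.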
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <torch/library.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#endif

#include <algorithm>
#include <limits> // std::numeric_limits, used for the uint8 output range below
namespace at {
namespace native {

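// Dispatch stub declared in QuantizedOps.h; the actual kernel is registered
// per CPU capability elsewhere via REGISTER_DISPATCH.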
DEFINE_DISPATCH(qhardswish_stub);

namespace {

#ifdef USE_PYTORCH_QNNPACK
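// Runs hardswish through QNNPACK. Expects a contiguous quint8 input and a
// preallocated output tensor qy that already carries the desired scale and
// zero point.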
Tensor qnnpack_hardswish(const Tensor& qx, Tensor& qy) {
  TORCH_CHECK(qx.ndimension() > 0, "qnnpack_hardswish(): Got empty input tensor");
  TORCH_CHECK(qx.scalar_type() == c10::kQUInt8,
                "qnnpack_hardswish(): Expected input data type to be ",
                toString(c10::kQUInt8),
                " but got ",
                toString(qx.scalar_type()));
  initQNNPACK();

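  // QNNPACK treats the input as a 2D (batch, channels) matrix: dim 0 is the
  // batch, and all remaining dimensions are flattened into the channel count.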
  size_t num_elems = qx.numel() / qx.size(0);
  const auto i_zero_point = qx.q_zero_point();
  const auto i_scale = qx.q_scale();
  const auto o_zero_point = qy.q_zero_point();
  const auto o_scale = qy.q_scale();

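  // QNNPACK operators follow a create/setup/run lifecycle: create bakes the
  // quantization parameters into the operator, setup binds the data pointers.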
  pytorch_qnnp_operator_t hardswish_op{nullptr};
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_hardswish_nc_q8(
    num_elems, // channels
    i_zero_point,
    i_scale,
    o_zero_point,
    o_scale,
    std::numeric_limits<uint8_t>::min(), // output min
    std::numeric_limits<uint8_t>::max(), // output max
    0, // flags
    &hardswish_op);

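  // Transfer ownership to a smart pointer right away so the operator is
  // freed even if one of the asserts below throws.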
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(hardswish_op);

  TORCH_INTERNAL_ASSERT(createStatus == pytorch_qnnp_status_success,
                        "failed to create QNNPACK Hardswish operator");

  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_hardswish_nc_q8(
    hardswish_op,
    qx.size(0), // batch size
    (uint8_t*)qx.data_ptr<c10::quint8>(), // input data
    num_elems, // input stride
    (uint8_t*)qy.data_ptr<c10::quint8>(), // output data
    num_elems); // output stride
  TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
                        "failed to setup QNNPACK Hardswish operator");

  pthreadpool_t threadpool = caffe2::pthreadpool_();

  const pytorch_qnnp_status runStatus =
    pytorch_qnnp_run_operator(hardswish_op, threadpool);

  TORCH_INTERNAL_ASSERT(
    runStatus == pytorch_qnnp_status_success,
    "failed to run QNNPACK Hardswish operator");
  return qy;
}
#endif // USE_PYTORCH_QNNPACK

} // namespace

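// Allocates the output with the requested quantization parameters, then
// either takes the QNNPACK fast path (QNNPACK engine selected and quint8
// input; note QNNPACK needs contiguous memory) or falls back to the native
// dispatch stub.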
static Tensor quantized_hardswish(const Tensor& qx, double output_scale, int64_t output_zero_point) {
  Tensor qy = at::_empty_affine_quantized(
      qx.sizes(),
      at::device(kCPU).dtype(qx.scalar_type()),
      output_scale,
      output_zero_point,
      qx.suggest_memory_format());
#ifdef USE_PYTORCH_QNNPACK
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
      qx.scalar_type() == kQUInt8) {
    Tensor qx_contig = qx.contiguous(qx.suggest_memory_format());
    qnnpack_hardswish(qx_contig, qy);
    return qy;
  }
#endif  // USE_PYTORCH_QNNPACK
  qhardswish_stub(qx.device().type(), qx, qy);
  return qy;
}

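// Register the kernel for the QuantizedCPU dispatch key; it is reachable
// from Python as torch.ops.quantized.hardswish(qx, output_scale,
// output_zero_point).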
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::hardswish"), TORCH_FN(quantized_hardswish));
}

}}  // namespace at::native