// aten/src/ATen/native/quantized/cpu/qtanh.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <c10/util/irange.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#include <limits> // std::numeric_limits

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/tanh_native.h>
#endif

namespace at {
namespace native {

DEFINE_DISPATCH(qtanh_stub);

#ifdef USE_PYTORCH_QNNPACK
// This ALWAYS outputs scale=2.0/256, zp=128, dtype=quint8
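// Rationale: tanh maps every input into [-1, 1], so the 256 quint8 levels
// are spread evenly across that range (scale = 2/256) with the zero point
// centered at 128. QNNPACK supports only these output quantization
// parameters for tanh.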
static Tensor qnnpack_tanh(Tensor input) {
  TORCH_CHECK(input.ndimension() > 0, "qnnpack_tanh(): Got empty input tensor");
  TORCH_CHECK(input.scalar_type() == c10::kQUInt8,
              "qnnpack_tanh(): Expected input data type ",
              toString(c10::kQUInt8),
              " but got ",
              toString(input.scalar_type()));
  Tensor qy;
  constexpr float output_scale = 2.0f / 256.0f;
  constexpr int32_t output_zero_point = 128;

  initQNNPACK();

  Tensor input_contig = input.contiguous(input.suggest_memory_format());
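  // QNNPACK's *_nc_q8 operators view the tensor as NC: dim 0 is the batch
  // dimension and the product of all remaining dims is the channel count.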
  size_t num_elems = 1;
  for (const auto i : c10::irange(1, input_contig.ndimension())) {
    num_elems *= input_contig.size(i);
  }
  const auto zero_point = input_contig.q_zero_point();
  const auto scale = input_contig.q_scale();

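  // Create the operator. Output min/max span the full uint8 range, so no
  // extra clamping is applied beyond the quantized representation itself.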
  pytorch_qnnp_operator_t tanh_op{nullptr};
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_tanh_nc_q8(
    num_elems /* channels */,
    zero_point /* input zero point */,
    scale /* input scale */,
    output_zero_point /* output zero point */,
    output_scale /* output scale */,
    std::numeric_limits<uint8_t>::min() /* output min */,
    std::numeric_limits<uint8_t>::max() /* output max */,
    0 /* flags */,
    &tanh_op);

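  // Wrap the raw operator handle immediately so it is freed even if one of
  // the assertions below throws.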
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(tanh_op);

  TORCH_INTERNAL_ASSERT(createStatus == pytorch_qnnp_status_success,
                        "failed to create QNNPACK TanH operator");
  qy = at::_empty_affine_quantized(
    input_contig.sizes(),
    at::device(kCPU).dtype(input_contig.dtype()),
    output_scale,
    output_zero_point,
    input_contig.suggest_memory_format());

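  // Bind input/output buffers. Rows are contiguous, so both strides equal
  // the per-batch element count.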
  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_tanh_nc_q8(
    tanh_op,
    input_contig.size(0) /* batch size */,
    (uint8_t*)input_contig.data_ptr<c10::quint8>() /* input data */,
    num_elems /* input stride */,
    (uint8_t*)qy.data_ptr<c10::quint8>() /* output data */,
    num_elems /* output stride */);
  TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
                        "failed to setup QNNPACK TanH operator");

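  // Execute on caffe2's shared thread pool.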
  pthreadpool_t threadpool = caffe2::pthreadpool_();

  const pytorch_qnnp_status runStatus =
    pytorch_qnnp_run_operator(tanh_op, threadpool);

  TORCH_INTERNAL_ASSERT(
    runStatus == pytorch_qnnp_status_success,
    "failed to run QNNPACK TanH operator");
  return qy;
}
#endif  // USE_PYTORCH_QNNPACK

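// Entry point for quantized tanh on CPU. The QNNPACK path is taken only when
// the QNNPACK engine is selected and the input is quint8; everything else
// falls through to the registered qtanh_stub kernel.
//
// A minimal usage sketch (hypothetical values; assumes a QNNPACK-enabled
// build) showing how this kernel is reached from the C++ frontend:
//
//   at::Tensor x = at::rand({2, 8});
//   at::Tensor qx = at::quantize_per_tensor(
//       x, /*scale=*/0.1, /*zero_point=*/128, at::kQUInt8);
//   at::Tensor qy = at::tanh(qx);  // output: scale 2/256, zero_point 128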
Tensor tanh_quantized_cpu(const Tensor& qx) {
#ifdef USE_PYTORCH_QNNPACK
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
      qx.scalar_type() == kQUInt8) {
    return qnnpack_tanh(qx);
  }
#endif  // USE_PYTORCH_QNNPACK
  Tensor qy;
  qtanh_stub(qx.device().type(), qx, qy);
  return qy;
}
}}  // namespace at::native