#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <c10/util/irange.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#include <limits>
#include <memory>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/tanh_native.h>
#endif

namespace at {
namespace native {

DEFINE_DISPATCH(qtanh_stub);

#ifdef USE_PYTORCH_QNNPACK
// This ALWAYS outputs scale=2.0/256, zp=128, dtype=quint8
static Tensor qnnpack_tanh(Tensor input) {
  TORCH_CHECK(input.ndimension() > 0, "qnnpack_tanh(): Got empty input tensor");
  TORCH_CHECK(input.scalar_type() == c10::kQUInt8,
              "qnnpack_tanh(): Expected input data type ",
              toString(c10::kQUInt8),
              " but got ",
              toString(input.scalar_type()));
  Tensor qy;
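  // tanh maps into [-1, 1]; an affine quint8 tensor with scale 2/256 and
  // zero point 128 spans exactly that range, so the output qparams are
  // fixed rather than derived from the input.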
  constexpr float output_scale = 2.0f / 256.0f;
  constexpr int32_t output_zero_point = 128;

  initQNNPACK();

  Tensor input_contig = input.contiguous(input.suggest_memory_format());
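  // QNNPACK's *_nc_q8 operators view the tensor as a 2-D (batch, channels)
  // matrix: dim 0 is the batch and all remaining dims fold into the channels.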
  size_t num_elems = 1;
  for (const auto i : c10::irange(1, input_contig.ndimension())) {
    num_elems *= input_contig.size(i);
  }
  const auto zero_point = input_contig.q_zero_point();
  const auto scale = input_contig.q_scale();

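  // Create the operator with the input/output qparams baked in; the output
  // is clamped to the full quint8 range, i.e. effectively unclamped.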
  pytorch_qnnp_operator_t tanh_op{nullptr};
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_tanh_nc_q8(
      num_elems /* channels */,
      zero_point /* input zero point */,
      scale /* input scale */,
      output_zero_point /* output zero point */,
      output_scale /* output scale */,
      std::numeric_limits<uint8_t>::min() /* output min */,
      std::numeric_limits<uint8_t>::max() /* output max */,
      0 /* flags */,
      &tanh_op);

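  // Hand ownership to a smart pointer so the operator is destroyed even if
  // one of the asserts below throws.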
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(tanh_op);

  TORCH_INTERNAL_ASSERT(createStatus == pytorch_qnnp_status_success,
                        "failed to create QNNPACK TanH operator");
  qy = at::_empty_affine_quantized(
      input_contig.sizes(),
      at::device(kCPU).dtype(input_contig.dtype()),
      output_scale,
      output_zero_point,
      input_contig.suggest_memory_format());

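  // The input is contiguous, so the row stride of both matrices equals the
  // channel count.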
  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_tanh_nc_q8(
      tanh_op,
      input_contig.size(0) /* batch size */,
      (uint8_t*)input_contig.data_ptr<c10::quint8>() /* input data */,
      num_elems /* input stride */,
      (uint8_t*)qy.data_ptr<c10::quint8>() /* output data */,
      num_elems /* output stride */);
  TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
                        "failed to setup QNNPACK TanH operator");

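  // Execute on caffe2's global thread pool, which QNNPACK ops share.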
  pthreadpool_t threadpool = caffe2::pthreadpool_();

  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(tanh_op, threadpool);

  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK TanH operator");
  return qy;
}
#endif // USE_PYTORCH_QNNPACK

Tensor tanh_quantized_cpu(const Tensor& qx) {
#ifdef USE_PYTORCH_QNNPACK
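  // The QNNPACK fast path only handles quint8 inputs; everything else falls
  // through to the dispatched kernel below.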
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
      qx.scalar_type() == kQUInt8) {
    return qnnpack_tanh(qx);
  }
#endif // USE_PYTORCH_QNNPACK
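  // qtanh_stub is the DispatchStub defined above; it routes to the
  // per-backend quantized tanh kernel, which allocates qy itself.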
  Tensor qy;
  qtanh_stub(qx.device().type(), qx, qy);
  return qy;
}
}} // namespace at::native