#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <torch/library.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#endif

#include <algorithm>
#include <limits> // std::numeric_limits

namespace at {
namespace native {

DEFINE_DISPATCH(qhardswish_stub);

namespace {

#ifdef USE_PYTORCH_QNNPACK
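// Applies hardswish to a quint8 tensor via QNNPACK, writing the result into
// the pre-allocated output tensor qy (which carries the output scale and
// zero point).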
Tensor qnnpack_hardswish(const Tensor& qx, Tensor& qy) {
  TORCH_CHECK(qx.ndimension() > 0, "qnnpack_hardswish(): Got empty input tensor");
  TORCH_CHECK(qx.scalar_type() == c10::kQUInt8,
              "qnnpack_hardswish(): Expected input data type to be ",
              toString(c10::kQUInt8),
              " but got ",
              toString(qx.scalar_type()));
  initQNNPACK();

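  // QNNPACK treats the input as qx.size(0) batches of num_elems contiguous
  // elements each.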
  size_t num_elems = qx.numel() / qx.size(0);
  const auto i_zero_point = qx.q_zero_point();
  const auto i_scale = qx.q_scale();
  const auto o_zero_point = qy.q_zero_point();
  const auto o_scale = qy.q_scale();

  pytorch_qnnp_operator_t hardswish_op{nullptr};
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_hardswish_nc_q8(
      num_elems, // channels
      i_zero_point,
      i_scale,
      o_zero_point,
      o_scale,
      std::numeric_limits<uint8_t>::min(), // output min
      std::numeric_limits<uint8_t>::max(), // output max
      0, // flags
      &hardswish_op);

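  // Hand ownership to a smart pointer so the operator is released even if a
  // later check throws.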
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(hardswish_op);

  TORCH_INTERNAL_ASSERT(createStatus == pytorch_qnnp_status_success,
                        "failed to create QNNPACK Hardswish operator");

  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_hardswish_nc_q8(
      hardswish_op,
      qx.size(0), // batch size
      (uint8_t*)qx.data_ptr<c10::quint8>(), // input data
      num_elems, // input stride
      (uint8_t*)qy.data_ptr<c10::quint8>(), // output data
      num_elems); // output stride
  TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
                        "failed to setup QNNPACK Hardswish operator");

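  // Execute the operator on the shared caffe2 threadpool.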
  pthreadpool_t threadpool = caffe2::pthreadpool_();

  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(hardswish_op, threadpool);

  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK Hardswish operator");
  return qy;
}
#endif // USE_PYTORCH_QNNPACK

} // namespace

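// Allocates the quantized output with the requested scale/zero point, then
// either takes the QNNPACK path (quint8 input with the QNNPACK engine) or
// falls back to the native qhardswish_stub kernel.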
static Tensor quantized_hardswish(const Tensor& qx, double output_scale, int64_t output_zero_point) {
  Tensor qy = at::_empty_affine_quantized(
      qx.sizes(),
      at::device(kCPU).dtype(qx.scalar_type()),
      output_scale,
      output_zero_point,
      qx.suggest_memory_format());
#ifdef USE_PYTORCH_QNNPACK
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
      qx.scalar_type() == kQUInt8) {
    Tensor qx_contig = qx.contiguous(qx.suggest_memory_format());
    qnnpack_hardswish(qx_contig, qy);
    return qy;
  }
#endif // USE_PYTORCH_QNNPACK
  qhardswish_stub(qx.device().type(), qx, qy);
  return qy;
}

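// Register the kernel for quantized::hardswish on the QuantizedCPU backend.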
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::hardswish"), TORCH_FN(quantized_hardswish));
}

}} // namespace at::native