// aten/src/ATen/native/quantized/cpu/qmul.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <torch/library.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/XnnpackUtils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/quantized/Quantizer.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/_empty_affine_quantized_native.h>
#include <ATen/ops/empty_like.h>
#endif

#include <algorithm>

namespace at {
namespace native {

DEFINE_DISPATCH(qmul_relu_stub);
DEFINE_DISPATCH(qmul_stub);
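// These DispatchStub tables are defined here; the CPU kernels that implement
// quantized multiply (and the ReLU-fused variant) are registered separately
// (via REGISTER_DISPATCH in the quantized CPU kernel files) and looked up by
// device type when the stubs are invoked below.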

namespace {

inline void check_inputs(const Tensor& qa, const Tensor& qb) {
  TORCH_CHECK(qa.qscheme() == kPerTensorAffine,
              "Only per tensor quantization is supported in Mul.");
  TORCH_CHECK(qa.scalar_type() == qb.scalar_type(),
              "Mul operands should have the same data type.");
  TORCH_CHECK(qa.qscheme() == qb.qscheme(),
              "Both inputs to Mul must have the same quantization scheme.");
}

// Note: out is assumed to be the same size as self and other.
// Note: Multiplication is only supported when self, other, and out are of the
//       same dtype.
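// Note: Conceptually, the kernels behind qmul_stub/qmul_relu_stub dequantize,
//       multiply, and requantize in one pass; for per-tensor affine inputs the
//       result is roughly
//         out_q = clamp(round((self_scale * other_scale / out_scale)
//                             * (self_q - self_zp) * (other_q - other_zp))
//                       + out_zp, q_min, q_max),
//       additionally clamped at out_zp when ReLU is fused.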
template <bool ReLUFused = false>
Tensor _mul_out(Tensor& out, const Tensor& self, const Tensor& other) {
  if (ReLUFused) {
    qmul_relu_stub(self.device().type(), out, self, other);
  } else {
    qmul_stub(self.device().type(), out, self, other);
  }
  return out;
}

#ifdef USE_XNNPACK
template <typename scalar_t, bool ReLUFused = false>
Tensor _mul_out_xnnpack(
    const Tensor& self,
    const Tensor& other,
    double output_scale,
    int64_t output_zero_point) {
  using underlying_t = typename scalar_t::underlying;

  const std::string func_name = "xnnp_mul()";
  TORCH_CHECK(self.ndimension() > 0, func_name, ": Got empty input tensor.");
  TORCH_CHECK(
      at::native::xnnpack::available(), func_name, ": XNNPACK is not available");

  // Use qa's memory format for qb so the XNNPACK kernel can flatten all the
  // dims.
  auto qa_mem_format = self.suggest_memory_format();
  Tensor self_contig = self.contiguous(qa_mem_format);
  Tensor other_contig = other.contiguous(qa_mem_format);

  Tensor out = at::native::empty_affine_quantized(
      at::infer_size_dimvector(self_contig.sizes(), other_contig.sizes()),
      self.scalar_type(),
      std::nullopt /* layout */,
      kCPU,
      std::nullopt /* pin_memory */,
      output_scale,
      output_zero_point,
      qa_mem_format);

  if (self_contig.size(0) == 0) {
    return out;
  }

  int64_t self_zero_point = self_contig.q_zero_point();
  double self_scale = self_contig.q_scale();
  int64_t other_zero_point = other_contig.q_zero_point();
  double other_scale = other_contig.q_scale();

  int64_t output_min = std::numeric_limits<underlying_t>::min();
  int64_t output_max = std::numeric_limits<underlying_t>::max();

  if (ReLUFused) {
    /*
     * FIXME: use activationLimits<T>()
     * With <T>, MSVC runs into "error C3862: identifier activationLimits not
     * found".
     */
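    // With ReLU fused, the dequantized output must be >= 0; since
    // out = out_scale * (q - out_zero_point), that means q >= out_zero_point,
    // so the lower clamp bound is raised to the (range-clamped) zero point.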
    constexpr int64_t qmin = std::numeric_limits<underlying_t>::min();
    constexpr int64_t qmax = std::numeric_limits<underlying_t>::max();
    int64_t qvalue = static_cast<int64_t>(output_zero_point);
    qvalue = std::max<int64_t>(qvalue, qmin);
    output_min = static_cast<underlying_t>(std::min<int64_t>(qvalue, qmax));
  }

  xnn_operator_t xnnp_op = nullptr;
  xnnpack_operator xnnp_qmul_operator;

  // create xnnpack multiply operator ...
  auto status = xnn_create_multiply_nd_qs8(
      self_zero_point,
      self_scale,
      other_zero_point,
      other_scale,
      static_cast<underlying_t>(output_zero_point),
      static_cast<float>(output_scale),
      output_min,
      output_max,
      0,
      &xnnp_op);

  TORCH_CHECK(
      status == xnn_status_success,
      func_name,
      ": xnn create operator failed(",
      status,
      ")!");
  xnnp_qmul_operator = xnnpack_operator(xnnp_op);

  const auto self_shape = xnnp_utils::get_mem_format_aware_shape(self_contig);
  const auto other_shape = xnnp_utils::get_mem_format_aware_shape(other_contig);

  // reshape operator
  status = xnn_reshape_multiply_nd_qs8(
      xnnp_qmul_operator.get(),
      self_shape.size(),
      self_shape.data(),
      other_shape.size(),
      other_shape.data(),
      caffe2::pthreadpool_());

  TORCH_CHECK(
      status == xnn_status_success,
      func_name,
      ": xnn reshape operator failed(",
      status,
      ")!");

  // set up operator
  status = xnn_setup_multiply_nd_qs8(
      xnnp_qmul_operator.get(),
      reinterpret_cast<const underlying_t*>(self_contig.data_ptr<scalar_t>()),
      reinterpret_cast<const underlying_t*>(other_contig.data_ptr<scalar_t>()),
      reinterpret_cast<underlying_t*>(out.data_ptr<scalar_t>()));

  TORCH_CHECK(
      status == xnn_status_success,
      func_name,
      ": xnn setup operator failed(",
      status,
      ")!");

  // Run the operator
  status = xnn_run_operator(
      xnnp_qmul_operator.get(), /* xnn_operator_t op */
      caffe2::pthreadpool_()); /* pthreadpool_t threadpool */
  TORCH_CHECK(
      status == xnn_status_success,
      func_name,
      ": xnn run operator failed(",
      status,
      ")");

  return out;
}

#endif // USE_XNNPACK

template <bool ReLUFused = false>
Tensor _mul_scalar_out(Tensor& out, const Tensor& self, const Scalar& other) {
  int64_t self_zero_point = self.q_zero_point();
  double self_scale = self.q_scale();
  double other_val = other.toDouble();

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double scale_prime;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t zero_point_prime;

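  // Multiplying x = self_scale * (q - zp) by a real scalar k folds into the
  // quantization parameters (handled by the three branches below):
  //   k > 0 : x * k = (k * self_scale) * (q - zp)  -> same q and zp,
  //           scale' = k * self_scale.
  //   k == 0: the result is identically zero       -> scale' = 1, zp' = 0.
  //   k < 0 : x * k = (|k| * self_scale) * (zp - q); with q' = q_max + q_min - q
  //           and zp' = q_max + q_min - zp we get q' - zp' = zp - q, so the
  //           stored values are mirrored and scale' = |k| * self_scale.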
  AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qmul_scalar", [&]() {
    // NOLINTNEXTLINE(bugprone-signed-char-misuse)
    int64_t q_min = std::numeric_limits<underlying_t>::min();
    int64_t q_max = std::numeric_limits<underlying_t>::max();

    if (other_val > 0.0) {
      scale_prime = other_val * self_scale;
      zero_point_prime = self_zero_point;

      if (ReLUFused) {
        qrelu_stub(self.device().type(), self, out);
      } else {
        out.copy_(self);
      }
      set_quantizer_(out, make_per_tensor_affine_quantizer(
          scale_prime, zero_point_prime, self.scalar_type()));
    } else if (other_val == 0.0) {
      scale_prime = 1.0;
      zero_point_prime = 0;

      // Strided "memset"
      // Set all values to 0
      auto iter = TensorIterator::unary_op(out, self);
      cpu_kernel_vec(
          iter,
          [&](scalar_t a) -> scalar_t { return scalar_t(0); },
          [&](Vectorized<scalar_t> vec) -> Vectorized<scalar_t> {
            return Vectorized<scalar_t>(scalar_t(0));
          });
      set_quantizer_(out, make_per_tensor_affine_quantizer(
          scale_prime, zero_point_prime, self.scalar_type()));
    } else /* other_val < 0.0 */ {
      scale_prime = std::abs(other_val) * self_scale;
      zero_point_prime = q_max - (self_zero_point - q_min);

      // xq' = q_max + q_min - x_q
      auto iter = TensorIterator::unary_op(out, self);
      cpu_kernel(
          iter,
          [&](scalar_t a) -> scalar_t {
            a = scalar_t(underlying_t(q_max + q_min - a.val_));
            if (ReLUFused) {
              a = scalar_t(std::max(a.val_, underlying_t(zero_point_prime)));
            }
            return a;
          });
      set_quantizer_(out, make_per_tensor_affine_quantizer(
          scale_prime, zero_point_prime, self.scalar_type()));
    }
  });

  return out;
}

template <bool ReLUFused = false>
class QMul final {
 public:
  static Tensor run(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
    check_inputs(qa, qb);
#ifdef USE_XNNPACK
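    // XNNPACK's qs8 multiply handles only per-tensor qint8; the zero_point <
    // q_max guard below presumably avoids a degenerate clamp range when ReLU
    // is fused (output_min would otherwise equal output_max).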
    int64_t q_max = std::numeric_limits<c10::qint8::underlying>::max();
    if (zero_point < q_max && qa.scalar_type() == kQInt8) {
      return _mul_out_xnnpack<c10::qint8, ReLUFused>(qa, qb, scale, zero_point);
    }
#endif // USE_XNNPACK

    auto qc = at::_empty_affine_quantized(
        infer_size_dimvector(qa.sizes(), qb.sizes()),
        at::device(kCPU).dtype(qa.scalar_type()),
        scale,
        zero_point,
        qa.suggest_memory_format());

    return _mul_out<ReLUFused>(qc, qa, qb);
  }
};

template <bool ReLUFused = false>
class QMulOut final {
 public:
  static Tensor run(at::Tensor qa, at::Tensor qb, Tensor out) {
    check_inputs(qa, qb);
    return _mul_out<ReLUFused>(out, qa, qb);
  }
};

template <bool ReLUFused = false>
class QMulScalar final {
 public:
  static Tensor run(Tensor qa, const Scalar& b) {
    TORCH_CHECK(qa.qscheme() == kPerTensorAffine ||
                qa.qscheme() == kPerTensorSymmetric,
                "Only per tensor quantization is supported in Mul.");
    auto qc = at::empty_like(qa, qa.suggest_memory_format());
    return _mul_scalar_out<ReLUFused>(qc, qa, b);
  }
};

template <bool ReLUFused = false>
class QMulScalar2 final {
 public:
  static Tensor run(const Scalar& b, Tensor qa) {
    TORCH_CHECK(qa.qscheme() == kPerTensorAffine ||
                qa.qscheme() == kPerTensorSymmetric,
                "Only per tensor quantization is supported in Mul.");
    auto qc = at::empty_like(qa, qa.suggest_memory_format());
    return _mul_scalar_out<ReLUFused>(qc, qa, b);
  }
};

template <bool ReLUFused = false>
class QMulScalarOut final {
 public:
  static Tensor run(Tensor qa, const Scalar& b, Tensor out) {
    check_inputs(qa, out);
    return _mul_scalar_out<ReLUFused>(out, qa, b);
  }
};

// `torch.jit.trace` will trace a Scalar as a Tensor.
// This can be removed after broadcast is supported and
// all variations of `quantized::mul` are merged into `quantized::mul`.
template <bool ReLUFused = false>
class QMulScalarTensor final {
 public:
  static Tensor run(Tensor qa, Tensor b) {
    TORCH_CHECK(qa.qscheme() == kPerTensorAffine ||
                qa.qscheme() == kPerTensorSymmetric,
                "Only per tensor quantization is supported in Mul.");
    auto qc = at::empty_like(qa, qa.suggest_memory_format());
    return _mul_scalar_out<ReLUFused>(qc, qa, b.item());
  }
};

// `torch.jit.trace` will trace a Scalar as a Tensor.
// This can be removed after broadcast is supported and
// all variations of `quantized::mul` are merged into `quantized::mul`.
template <bool ReLUFused = false>
class QMulScalarTensorOut final {
 public:
  static Tensor run(Tensor qa, Tensor b, Tensor out) {
    check_inputs(qa, out);
    return _mul_scalar_out<ReLUFused>(out, qa, b.item());
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul"),                 TORCH_FN(QMul</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul.out"),             TORCH_FN(QMulOut</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar"),          TORCH_FN(QMulScalar</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar2"),         TORCH_FN(QMulScalar2</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar_out"),      TORCH_FN(QMulScalarOut</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu"),            TORCH_FN(QMul</*ReLUFused=*/true>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.out"),        TORCH_FN(QMulOut</*ReLUFused=*/true>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar"),     TORCH_FN(QMulScalar</*ReLUFused=*/true>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar2"),    TORCH_FN(QMulScalar2</*ReLUFused=*/true>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar_out"), TORCH_FN(QMulScalarOut</*ReLUFused=*/true>::run));
  // deprecated functions, kept for backward compatibility
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_out"),             TORCH_FN(QMulOut</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu_out"),        TORCH_FN(QMulOut</*ReLUFused=*/true>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar"),          TORCH_FN(QMulScalar</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu"),     TORCH_FN(QMulScalar</*ReLUFused=*/true>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_out"),      TORCH_FN(QMulScalarOut</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu_out"), TORCH_FN(QMulScalarOut</*ReLUFused=*/true>::run));
  // TODO: remove after broadcasting is supported
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar.Tensor"),          TORCH_FN(QMulScalarTensor</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu.Tensor"),     TORCH_FN(QMulScalarTensor</*ReLUFused=*/true>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_out.Tensor"),      TORCH_FN(QMulScalarTensorOut</*ReLUFused=*/false>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu_out.Tensor"), TORCH_FN(QMulScalarTensorOut</*ReLUFused=*/true>::run));
}
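
// Usage sketch (not part of this file's behavior): once registered, these
// kernels back the `quantized::*` ops, e.g. from Python
//   torch.ops.quantized.mul(qa, qb, out_scale, out_zero_point)
//   torch.ops.quantized.mul.Scalar(qa, 2.0)
// where qa/qb are per-tensor affine quantized CPU tensors; from C++ they are
// reached through the dispatcher under the QuantizedCPU key.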

}  // namespace
}}  // namespace at::native