xref: /aosp_15_r20/external/executorch/kernels/quantized/cpu/op_mixed_mm.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/vec_ops.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 
12 namespace torch {
13 namespace executor {
14 namespace native {
15 
16 using Tensor = exec_aten::Tensor;
17 
check_quantized_mixed_mm_args(const Tensor & in,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,Tensor & out)18 bool check_quantized_mixed_mm_args(
19     const Tensor& in,
20     const Tensor& weight,
21     const Tensor& weight_scales,
22     const exec_aten::optional<Tensor>& opt_weight_zero_points,
23     Tensor& out) {
24   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
25   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2));
26   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight_scales, 1));
27   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));
28 
29   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, weight, 0));
30   ET_LOG_AND_RETURN_IF_FALSE(
31       tensors_have_same_size_at_dims(weight_scales, 0, weight, 0));
32 
33   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight_scales, out));
34   ET_LOG_MSG_AND_RETURN_IF_FALSE(
35       weight.scalar_type() == ScalarType::Char, "weight dtype must be int8");
36   ET_LOG_MSG_AND_RETURN_IF_FALSE(
37       in.scalar_type() == ScalarType::Float ||
38           in.scalar_type() == ScalarType::Half,
39       "input dtype must be Float or Half");
40 
41   if (opt_weight_zero_points.has_value()) {
42     ET_LOG_AND_RETURN_IF_FALSE(
43         tensors_have_same_shape(opt_weight_zero_points.value(), weight_scales));
44     ET_LOG_AND_RETURN_IF_FALSE(
45         tensors_have_same_dtype(opt_weight_zero_points.value(), in));
46   }
47 
48   // Support for non-null zero points is not implemented yet.
49   ET_LOG_MSG_AND_RETURN_IF_FALSE(
50       !opt_weight_zero_points.has_value(), "zero points not supported yet.");
51   return true;
52 }
53 
/**
 * Computes out = in @ dequantize(weight), where weight is int8 with one
 * dequantization scale per row (held in weight_scales).
 *
 * in is (m, n) Float or Half; weight is (n, p) int8. out is resized to
 * (m, p). Zero points are not supported yet (enforced by the arg check).
 *
 * @returns the (resized and filled) out tensor.
 */
Tensor& quantized_mixed_mm_out(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    Tensor& out) {
  ET_CHECK(check_quantized_mixed_mm_args(
      in, weight, weight_scales, opt_weight_zero_points, out));

  // The result of (m, n) x (n, p) is (m, p); resize out accordingly.
  exec_aten::SizesType out_sizes[2] = {
      static_cast<exec_aten::SizesType>(in.size(0)),
      static_cast<exec_aten::SizesType>(weight.size(1))};
  ET_CHECK(resize_tensor(out, {out_sizes, 2}) == Error::Ok);

  constexpr auto name = "quantized_decomposed::mixed_mm.out";

  // Dispatch on the floating-point dtype of the input; the arg check has
  // already guaranteed it is Float or Half and matches scales/out.
  ET_SWITCH_TWO_TYPES(Float, Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
    const size_t m = in.size(0);
    const size_t n = in.size(1);
    const size_t p = weight.size(1);

    // Fused dequantize + matmul over the int8 weight.
    vec_quantized_matmul_int8<CTYPE>(
        out.mutable_data_ptr<CTYPE>(),
        in.const_data_ptr<CTYPE>(),
        weight.const_data_ptr<int8_t>(),
        weight_scales.const_data_ptr<CTYPE>(),
        m,
        n,
        p);
  });

  return out;
}
89 
quantized_mixed_mm_out(KernelRuntimeContext & ctx,const Tensor & in,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,Tensor & out)90 Tensor& quantized_mixed_mm_out(
91     KernelRuntimeContext& ctx,
92     const Tensor& in,
93     const Tensor& weight,
94     const Tensor& weight_scales,
95     const exec_aten::optional<Tensor>& opt_weight_zero_points,
96     Tensor& out) {
97   // TODO(mcandales): Remove the need for this wrapper
98   (void)ctx;
99   return quantized_mixed_mm_out(
100       in, weight, weight_scales, opt_weight_zero_points, out);
101 }
102 
103 } // namespace native
104 } // namespace executor
105 } // namespace torch
106