1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/vec_ops.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11
12 namespace torch {
13 namespace executor {
14 namespace native {
15
16 using Tensor = exec_aten::Tensor;
17
check_quantized_mixed_mm_args(const Tensor & in,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,Tensor & out)18 bool check_quantized_mixed_mm_args(
19 const Tensor& in,
20 const Tensor& weight,
21 const Tensor& weight_scales,
22 const exec_aten::optional<Tensor>& opt_weight_zero_points,
23 Tensor& out) {
24 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
25 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2));
26 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight_scales, 1));
27 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));
28
29 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, weight, 0));
30 ET_LOG_AND_RETURN_IF_FALSE(
31 tensors_have_same_size_at_dims(weight_scales, 0, weight, 0));
32
33 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight_scales, out));
34 ET_LOG_MSG_AND_RETURN_IF_FALSE(
35 weight.scalar_type() == ScalarType::Char, "weight dtype must be int8");
36 ET_LOG_MSG_AND_RETURN_IF_FALSE(
37 in.scalar_type() == ScalarType::Float ||
38 in.scalar_type() == ScalarType::Half,
39 "input dtype must be Float or Half");
40
41 if (opt_weight_zero_points.has_value()) {
42 ET_LOG_AND_RETURN_IF_FALSE(
43 tensors_have_same_shape(opt_weight_zero_points.value(), weight_scales));
44 ET_LOG_AND_RETURN_IF_FALSE(
45 tensors_have_same_dtype(opt_weight_zero_points.value(), in));
46 }
47
48 // Support for non-null zero points is not implemented yet.
49 ET_LOG_MSG_AND_RETURN_IF_FALSE(
50 !opt_weight_zero_points.has_value(), "zero points not supported yet.");
51 return true;
52 }
53
/**
 * Computes out = in @ (weight_scales * weight), where weight holds int8
 * values, and in / weight_scales / out share a floating point dtype
 * (Float or Half). out is resized to [in.size(0), weight.size(1)].
 *
 * Zero points are validated to be absent by
 * check_quantized_mixed_mm_args; any argument violation aborts via
 * ET_CHECK.
 *
 * @returns `out`.
 */
Tensor& quantized_mixed_mm_out(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    Tensor& out) {
  ET_CHECK(check_quantized_mixed_mm_args(
      in, weight, weight_scales, opt_weight_zero_points, out));

  // An [m x n] @ [n x p] matmul produces an [m x p] result.
  const size_t m = in.size(0);
  const size_t n = in.size(1);
  const size_t p = weight.size(1);

  constexpr size_t kOutRank = 2;
  exec_aten::SizesType expected_out_sizes[kTensorDimensionLimit] = {
      static_cast<exec_aten::SizesType>(m),
      static_cast<exec_aten::SizesType>(p)};
  ET_CHECK(resize_tensor(out, {expected_out_sizes, kOutRank}) == Error::Ok);

  constexpr auto name = "quantized_decomposed::mixed_mm.out";

  // NOTE(review): `ctx` is not declared in this overload; the switch macro's
  // CONTEXT slot evidently does not expand it — confirm against
  // ET_SWITCH_TWO_TYPES's definition before changing.
  ET_SWITCH_TWO_TYPES(Float, Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
    vec_quantized_matmul_int8<CTYPE>(
        out.mutable_data_ptr<CTYPE>(),
        in.const_data_ptr<CTYPE>(),
        weight.const_data_ptr<int8_t>(),
        weight_scales.const_data_ptr<CTYPE>(),
        m,
        n,
        p);
  });

  return out;
}
89
quantized_mixed_mm_out(KernelRuntimeContext & ctx,const Tensor & in,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,Tensor & out)90 Tensor& quantized_mixed_mm_out(
91 KernelRuntimeContext& ctx,
92 const Tensor& in,
93 const Tensor& weight,
94 const Tensor& weight_scales,
95 const exec_aten::optional<Tensor>& opt_weight_zero_points,
96 Tensor& out) {
97 // TODO(mcandales): Remove the need for this wrapper
98 (void)ctx;
99 return quantized_mixed_mm_out(
100 in, weight, weight_scales, opt_weight_zero_points, out);
101 }
102
103 } // namespace native
104 } // namespace executor
105 } // namespace torch
106