// aten/src/ATen/native/quantized/cpu/qconv_dynamic.cpp
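//
// Dynamically quantized convolution: the fp32 input is quantized on the fly
// (per-tensor affine, uint8), convolved against prepacked int8 weights, and
// the quantized result is dequantized back to fp32. Backend implementations
// below: FBGEMM, QNNPACK, and oneDNN (when AT_MKLDNN_ENABLED).
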
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <algorithm>

#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>
#include <ATen/Parallel.h>
#include <ATen/SmallVector.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/dequantize.h>                           // for dequantize
#include <ATen/ops/quantize_per_tensor.h>
#endif

#ifdef USE_FBGEMM

template <int kSpatialDim>
at::Tensor PackedConvWeight<kSpatialDim>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range) {
  TORCH_CHECK(
      fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");

  float x_min, x_max;
  fbgemm::FindMinMax(
      /*m=*/input.data_ptr<float>(),
      /*min=*/&x_min,
      /*max=*/&x_max,
      /*len=*/input.numel());

  // Input tensor is quantized as 8-bit unsigned values
  static constexpr int precision = 8;
  static constexpr bool is_signed = false;

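  // With precision = 8 and is_signed = false, the qmin/qmax expressions below
  // evaluate to qmin = 0 and qmax = 255 (the full uint8 range). When
  // reduce_range is true, ChooseQuantizationParams narrows that range to
  // roughly half to help avoid intermediate accumulator overflow in some
  // x86 int8 kernels.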
  // Calculate scale and zero point for quantization of input tensor
  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
      /*qmax=*/
      is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);

  // Quantize input
  at::Tensor q_input = at::quantize_per_tensor(
      input, q_params.scale, q_params.zero_point, c10::kQUInt8);

  at::Tensor out =
      apply_impl<false>(q_input, q_params.scale, q_params.zero_point);

  return at::dequantize(out); // TODO: optimized kernel that outputs fp32 so
                              // this step isn't necessary
}

template at::Tensor PackedConvWeight<2>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range);

template at::Tensor PackedConvWeight<3>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range);

#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK

template <int kSpatialDim>
at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range) {
  if (reduce_range) {
    TORCH_WARN("Currently, qnnpack incorrectly ignores reduce_range when it is set to true; this may change in a future release.");
  }

  // On empty input, no output data will be generated,
  // so use arbitrary qparams.
  float x_min = 0;
  float x_max = 0;
  // Otherwise...
  if (input.numel() > 0) {
    x_min = input.min().item<float>();
    x_max = input.max().item<float>();
  }

  // Input tensor is quantized as 8-bit unsigned values
  static constexpr int precision = 8;
  static constexpr bool is_signed = false;

  // Calculate scale and zero point for quantization of input tensor
  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
      /*qmax=*/
      is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/false); // note: this is set to false rather than
                               // reduce_range for qnnpack

  // Quantize input
  at::Tensor q_input = at::quantize_per_tensor(
      input, q_params.scale, q_params.zero_point, c10::kQUInt8);

  at::Tensor out =
      apply_impl<false>(q_input, q_params.scale, q_params.zero_point);

  return at::dequantize(out); // TODO: optimized kernel that outputs fp32 so
                              // this step isn't necessary
}

template at::Tensor PackedConvWeightsQnnp<2>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range);

template at::Tensor PackedConvWeightsQnnp<3>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range);

#endif // USE_PYTORCH_QNNPACK

#if AT_MKLDNN_ENABLED()

template <int kSpatialDim>
at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range) {

  // Find min/max of input
  float x_max = 0, x_min = 0;
  if (input.numel() > 0) {
    x_min = input.min().item<float>();
    x_max = input.max().item<float>();
  }

  // Input tensor is quantized as 8-bit unsigned values
  static constexpr int precision = 8;
  static constexpr bool is_signed = false;

  // Calculate scale and zero point for quantization of input tensor
  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
      /*qmax=*/
      is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);

  // Quantize input
  at::Tensor q_input = at::quantize_per_tensor(
      input, q_params.scale, q_params.zero_point, c10::kQUInt8);

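  // Note: unlike the FBGEMM/QNNPACK paths, the oneDNN apply_impl takes an
  // extra optional accumulation tensor (elsewhere used for fused sum/add
  // post-ops); dynamic convolution has no fused post-op, so std::nullopt
  // is passed below.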
  at::Tensor out =
      apply_impl<false>(q_input, /*accum*/std::nullopt, q_params.scale, q_params.zero_point);

  // TODO: Modify ideep to allow fp32 input & output
  // to avoid explicit `quantize - dequantize`
  return at::dequantize(out);
}

template at::Tensor PackedConvWeightsOnednn<2>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range);

template at::Tensor PackedConvWeightsOnednn<3>::apply_dynamic(
    const at::Tensor& input,
    bool reduce_range);

#endif // AT_MKLDNN_ENABLED()

namespace at {
namespace native {
namespace {

// note: this handles both Conv and ConvTranspose, since the packed params
// record whether the weight is transposed (see transpose())
template <int kSpatialDim>
class QConvDynamicInt8 final {
 public:
  static at::Tensor run(
      at::Tensor input,
      const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>&
          packed_weight,
      bool reduce_range) {
    return packed_weight->apply_dynamic(input, reduce_range);
  }
};

// note: this handles both Conv and ConvTranspose, since the packed params
// record whether the weight is transposed (see transpose())
class QConv1dDynamicInt8 final {
 public:
  static at::Tensor run(
      at::Tensor input,
      const c10::intrusive_ptr<ConvPackedParamsBase<2>>& packed_weight,
      bool reduce_range) {
    at::Tensor output;
    // N, C, L -> N, C, 1, L
    input = input.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
    output = packed_weight->apply_dynamic(input, reduce_range);
    // N, C, 1, L -> N, C, L
    return output.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
  }
};

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::conv1d_dynamic"),
      TORCH_FN(QConv1dDynamicInt8::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::conv2d_dynamic"),
      TORCH_FN(QConvDynamicInt8<2>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::conv3d_dynamic"),
      TORCH_FN(QConvDynamicInt8<3>::run));

  // transpose
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_dynamic"),
      TORCH_FN(QConv1dDynamicInt8::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_dynamic"),
      TORCH_FN(QConvDynamicInt8<2>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_dynamic"),
      TORCH_FN(QConvDynamicInt8<3>::run));
}
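
// A minimal usage sketch (illustrative, not part of this file): once a weight
// has been prepacked, the kernels registered above are reachable from Python
// via the torch.ops.quantized namespace, roughly:
//
//   packed = torch.ops.quantized.conv2d_prepack(
//       q_weight, bias, stride, padding, dilation, groups)
//   y = torch.ops.quantized.conv2d_dynamic(x_fp32, packed, reduce_range)
//
// where q_weight is a quantized int8 weight tensor and x_fp32 is a plain
// float tensor; the activation quantize/dequantize happens inside
// apply_dynamic above.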

} // namespace
} // namespace native
} // namespace at