#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <torch/library.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/quantized_batch_norm_native.h>
#endif

#include <algorithm>

namespace at {
namespace native {

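// Dispatch stubs for the quantized batch norm kernels (plain and ReLU-fused).
// DEFINE_DISPATCH creates the stub objects here; the per-architecture kernels
// are registered against them elsewhere (via REGISTER_DISPATCH in the
// quantized CPU kernel sources).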
DEFINE_DISPATCH(qbatch_norm_stub);
DEFINE_DISPATCH(qbatch_norm_relu_stub);

namespace {
void compute_fused_params(
    const int64_t channels,
    const float* weight_data,
    const float* bias_data,
    const float* mean_data,
    const float* var_data,
    double eps,
    double input_scale,
    double output_scale,
    float* alpha_data,
    float* beta_data) {
  // Batch Normalization
  // output(n, c, h, w)
  //     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
  //         + bias(c)
  // We factor out inv_sigma(c) = 1 / sqrt(var(c) + eps).
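  //
  // In the quantized domain, with input = input_scale * (q_input - q_input_zp)
  // and output = output_scale * (q_output - q_output_zp), the same formula
  // folds into a single per-channel affine transform:
  //   q_output(n, c, h, w)
  //       = alpha(c) * (q_input(n, c, h, w) - q_input_zp) + beta(c) + q_output_zp
  // with
  //   alpha(c) = inv_sigma(c) * weight(c) * (input_scale / output_scale)
  //   beta(c)  = (bias(c) - mean(c) * inv_sigma(c) * weight(c)) / output_scale
  // The loop below precomputes alpha and beta; the kernels are expected to
  // apply them along with the zero points, rounding, and clamping.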
  for (const auto c : c10::irange(channels)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    float inv_sigma = 1.0 / std::sqrt(var_data[c] + static_cast<float>(eps));
    float weight_v = weight_data ? weight_data[c] : 1;
    float bias_v = bias_data ? bias_data[c] : 0;
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    alpha_data[c] = inv_sigma * weight_v * (input_scale / output_scale);
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    beta_data[c] = (bias_v - mean_data[c] * inv_sigma * weight_v) / output_scale;
  }
}

template <bool ReluFused>
Tensor q_batch_norm1d_impl(
    Tensor qx,
    std::optional<Tensor> mb_weight,
    std::optional<Tensor> mb_bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");
  const auto& weight = *mb_weight;
  const auto& bias = *mb_bias;

  if (qx.numel() == 0) {
    auto out = qx.clone();
    return out;
  }
  int64_t ndim = qx.dim();
  TORCH_CHECK(ndim == 2 || ndim == 3, "Expected an input tensor of rank 2 or 3.");
  const int64_t N = qx.size(0);
  const int64_t C = qx.size(1);
  const int64_t H = ndim == 3 ? qx.size(2) : 1;

  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");

  const float* weight_data = weight.template const_data_ptr<float>();
  const float* bias_data = bias.template const_data_ptr<float>();

  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");

  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  float* alpha_data = alpha.mutable_data_ptr<float>();
  float* beta_data = beta.mutable_data_ptr<float>();

  const float* mean_data = mean.template const_data_ptr<float>();
  const float* var_data = var.template const_data_ptr<float>();

  if (ndim == 2) {
    // create a fake H and W dimension so we can use NHWC
    qx = qx.unsqueeze(-1).unsqueeze(-1);
  } else {
    // create a fake W dimension so we can use NHWC
    qx = qx.unsqueeze(-1);
  }

  auto oSizes = qx.sizes();
  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
  Tensor qy = at::_empty_affine_quantized(
      oSizes,
      at::device(kCPU)
        .dtype(qx_nhwc.scalar_type())
        .memory_format(MemoryFormat::ChannelsLast),
      output_scale,
      output_zero_point,
      std::nullopt);

  compute_fused_params(
      C,
      weight_data,
      bias_data,
      mean_data,
      var_data,
      eps,
      qx.q_scale(),
      output_scale,
      alpha_data,
      beta_data);
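  // Hand off to the (optionally ReLU-fused) kernel behind the dispatch stub.
  // It applies the per-channel alpha/beta computed above to the channels-last
  // input and writes the requantized result into qy, using the input/output
  // zero points passed alongside.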
  if (ReluFused) {
    qbatch_norm_relu_stub(
        qx.device().type(),
        N,
        C,
        H,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  } else {
    qbatch_norm_stub(
        qx.device().type(),
        N,
        C,
        H,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  }
  // Remove the fake dimension(s) and go back to contiguous format
  // (since there is no real 4th dimension). Note, this has a performance
  // cost.
  Tensor result = qy.contiguous(MemoryFormat::Contiguous).squeeze(-1);
  if (ndim == 2) {
    result = result.squeeze(-1);
  }
  return result;
}

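// Same approach as the 1d variant, but the input is already 4d (NCHW), so no
// fake dimensions are needed; the spatial extent handed to the kernel is H * W.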
template <bool ReluFused>
Tensor q_batch_norm2d_impl(
    Tensor qx,
    std::optional<Tensor> mb_weight,
    std::optional<Tensor> mb_bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");
  const auto& weight = *mb_weight;
  const auto& bias = *mb_bias;

  if (qx.numel() == 0) {
    auto out = qx.clone();
    return out;
  }
  int64_t ndim = qx.dim();
  TORCH_CHECK(ndim == 4, "Expected an input tensor of rank 4.");
  const int64_t N = qx.size(0);
  const int64_t C = qx.size(1);
  const int64_t H = qx.size(2);
  const int64_t W = qx.size(3);

  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");

  const float* weight_data = weight.template const_data_ptr<float>();
  const float* bias_data = bias.template const_data_ptr<float>();

  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");

  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  float* alpha_data = alpha.mutable_data_ptr<float>();
  float* beta_data = beta.mutable_data_ptr<float>();

  const float* mean_data = mean.template const_data_ptr<float>();
  const float* var_data = var.template const_data_ptr<float>();

  auto oSizes = qx.sizes();
  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
  Tensor qy = at::_empty_affine_quantized(
      oSizes,
      at::device(kCPU)
        .dtype(qx_nhwc.scalar_type())
        .memory_format(MemoryFormat::ChannelsLast),
      output_scale,
      output_zero_point,
      std::nullopt);

  compute_fused_params(
      C,
      weight_data,
      bias_data,
      mean_data,
      var_data,
      eps,
      qx.q_scale(),
      output_scale,
      alpha_data,
      beta_data);
  if (ReluFused) {
    qbatch_norm_relu_stub(
        qx.device().type(),
        N,
        C,
        H * W,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  } else {
    qbatch_norm_stub(
        qx.device().type(),
        N,
        C,
        H * W,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  }
  return qy;
}

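// 3d (NCDHW) variant: same structure as the 2d one, but using the
// ChannelsLast3d memory format and a spatial extent of D * H * W.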
template <bool ReluFused>
Tensor q_batch_norm3d_impl(
    Tensor qx,
    std::optional<Tensor> mb_weight,
    std::optional<Tensor> mb_bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");

  const auto& weight = *mb_weight;
  const auto& bias = *mb_bias;

  if (qx.numel() == 0) {
    auto out = qx.clone();
    return out;
  }
  int64_t ndim = qx.dim();
  TORCH_CHECK(ndim == 5, "Expected an input tensor of rank 5.");
  const int64_t N = qx.size(0);
  const int64_t C = qx.size(1);
  const int64_t D = qx.size(2);
  const int64_t H = qx.size(3);
  const int64_t W = qx.size(4);

  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");

  const float* weight_data = weight.template const_data_ptr<float>();
  const float* bias_data = bias.template const_data_ptr<float>();

  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");

  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  float* alpha_data = alpha.mutable_data_ptr<float>();
  float* beta_data = beta.mutable_data_ptr<float>();

  const float* mean_data = mean.template const_data_ptr<float>();
  const float* var_data = var.template const_data_ptr<float>();

  auto oSizes = qx.sizes();
  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast3d);
  Tensor qy = at::_empty_affine_quantized(
      oSizes,
      at::device(kCPU)
        .dtype(qx_nhwc.scalar_type())
        .memory_format(MemoryFormat::ChannelsLast3d),
      output_scale,
      output_zero_point,
      std::nullopt);

  compute_fused_params(
      C,
      weight_data,
      bias_data,
      mean_data,
      var_data,
      eps,
      qx.q_scale(),
      output_scale,
      alpha_data,
      beta_data);

  if (ReluFused) {
    qbatch_norm_relu_stub(
        qx.device().type(),
        N,
        C,
        D * H * W,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  } else {
    qbatch_norm_stub(
        qx.device().type(),
        N,
        C,
        D * H * W,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  }
  return qy;
}

template <bool ReluFused>
Tensor q_batch_norm_impl(
    Tensor qx,
    std::optional<Tensor> mb_weight,
    std::optional<Tensor> mb_bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  Tensor qy;
  int64_t dim = qx.dim();
  if (dim == 2 || dim == 3) {
    qy = q_batch_norm1d_impl<ReluFused>(
        qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
  } else if (dim == 4) {
    qy = q_batch_norm2d_impl<ReluFused>(
        qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
  } else if (dim == 5) {
    qy = q_batch_norm3d_impl<ReluFused>(
        qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
  } else {
    TORCH_CHECK(false, "quantized::batch_norm only supports 2d, 3d, 4d, or 5d inputs.");
  }
  return qy;
}

} // namespace

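// Native entry point backing at::quantized_batch_norm (presumably registered
// through native_functions.yaml, given the quantized_batch_norm_native.h
// include above). It currently only handles the 4d (NCHW) case.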
Tensor quantized_batch_norm(
    const Tensor& qx,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    const Tensor& mean,
    const Tensor& var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});

  Tensor qy;
  // TODO: this should arguably support 3d as well
  qy = q_batch_norm2d_impl<false>(
      qx,
      weight.defined() ? std::make_optional(weight) : std::nullopt,
      bias.defined() ? std::make_optional(bias) : std::nullopt,
      mean, var, eps, output_scale, output_zero_point);
  return qy;
}
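
// The templated implementations above are bound below to the quantized::*
// operator schemas (the schemas themselves are declared elsewhere, in the
// quantized library registration). A caller going through the dispatcher
// would look roughly like the following sketch; the variable names are
// illustrative and the exact optional type may differ by PyTorch version:
//
//   auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("quantized::batch_norm2d", "")
//       .typed<Tensor(Tensor, std::optional<Tensor>, std::optional<Tensor>,
//                     Tensor, Tensor, double, double, int64_t)>();
//   Tensor qy = op.call(qx, weight, bias, mean, var, eps, output_scale,
//                       output_zero_point);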
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm"),        TORCH_FN(q_batch_norm_impl<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm_relu"),   TORCH_FN(q_batch_norm_impl<true>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm1d"),      TORCH_FN(q_batch_norm1d_impl<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm1d_relu"), TORCH_FN(q_batch_norm1d_impl<true>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm2d"),      TORCH_FN(q_batch_norm2d_impl<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm2d_relu"), TORCH_FN(q_batch_norm2d_impl<true>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm3d"),      TORCH_FN(q_batch_norm3d_impl<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm3d_relu"), TORCH_FN(q_batch_norm3d_impl<true>));
}

} // namespace native
} // namespace at