#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <torch/library.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/quantized_batch_norm_native.h>
#endif

#include <algorithm>

namespace at {
namespace native {

DEFINE_DISPATCH(qbatch_norm_stub);
DEFINE_DISPATCH(qbatch_norm_relu_stub);

namespace {
void compute_fused_params(
    const int64_t channels,
    const float* weight_data,
    const float* bias_data,
    const float* mean_data,
    const float* var_data,
    double eps,
    double input_scale,
    double output_scale,
    float* alpha_data,
    float* beta_data) {
  // Batch Normalization
  // output(n, c, h, w)
  //     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
  //         + bias(c)
  // We factor out inv_sigma(c) = 1 / sqrt(var(c) + eps).
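  // With x_float = input_scale * (qx - input_zero_point) and
  // y_float = output_scale * (qy - output_zero_point), this per-channel
  // affine transform folds (before rounding and clamping) into
  //     qy = output_zero_point + alpha(c) * (qx - input_zero_point) + beta(c)
  // where
  //     alpha(c) = inv_sigma(c) * weight(c) * input_scale / output_scale
  //     beta(c)  = (bias(c) - mean(c) * inv_sigma(c) * weight(c)) / output_scale
  // The loop below precomputes alpha and beta for the vectorized kernels.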
  for (const auto c : c10::irange(channels)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    float inv_sigma = 1.0 / std::sqrt(var_data[c] + static_cast<float>(eps));
    float weight_v = weight_data ? weight_data[c] : 1;
    float bias_v = bias_data ? bias_data[c] : 0;
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    alpha_data[c] = inv_sigma * weight_v * (input_scale / output_scale);
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    beta_data[c] = (bias_v - mean_data[c] * inv_sigma * weight_v) / output_scale;
  }
}

template <bool ReluFused>
Tensor q_batch_norm1d_impl(
    Tensor qx,
    std::optional<Tensor> mb_weight,
    std::optional<Tensor> mb_bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");
  const auto& weight = *mb_weight;
  const auto& bias = *mb_bias;

  if (qx.numel() == 0) {
    auto out = qx.clone();
    return out;
  }
  int64_t ndim = qx.dim();
  TORCH_CHECK(ndim == 2 || ndim == 3, "Expecting the input tensor of rank 2 or 3.");
  const int64_t N = qx.size(0);
  const int64_t C = qx.size(1);
  const int64_t H = ndim == 3 ? qx.size(2) : 1;

  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");

  const float* weight_data = weight.template const_data_ptr<float>();
  const float* bias_data = bias.template const_data_ptr<float>();

  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");

  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  float* alpha_data = alpha.mutable_data_ptr<float>();
  float* beta_data = beta.mutable_data_ptr<float>();

  const float* mean_data = mean.template const_data_ptr<float>();
  const float* var_data = var.template const_data_ptr<float>();

  if (ndim == 2) {
    // create a fake H and W dimension so we can use NHWC
    qx = qx.unsqueeze(-1).unsqueeze(-1);
  } else {
    // create a fake W dimension so we can use NHWC
    qx = qx.unsqueeze(-1);
  }
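  // qx is now 4-D: (N, C, 1, 1) in the rank-2 case or (N, C, H, 1) in the
  // rank-3 case, so the NHWC (ChannelsLast) batch norm kernels can be reused.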

  auto oSizes = qx.sizes();
  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
  Tensor qy = at::_empty_affine_quantized(
      oSizes,
      at::device(kCPU)
          .dtype(qx_nhwc.scalar_type())
          .memory_format(MemoryFormat::ChannelsLast),
      output_scale,
      output_zero_point,
      std::nullopt);

  compute_fused_params(
      C,
      weight_data,
      bias_data,
      mean_data,
      var_data,
      eps,
      qx.q_scale(),
      output_scale,
      alpha_data,
      beta_data);
  if (ReluFused) {
    qbatch_norm_relu_stub(
        qx.device().type(),
        N,
        C,
        H,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  } else {
    qbatch_norm_stub(
        qx.device().type(),
        N,
        C,
        H,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  }
  // Remove the fake dimension, and go back to contiguous format
  // (since there is no 4th channel). Note, this has a performance
  // cost.
  Tensor result = qy.contiguous(MemoryFormat::Contiguous).squeeze(-1);
  if (ndim == 2) {
    result = result.squeeze(-1);
  }
  return result;
}

template <bool ReluFused>
Tensor q_batch_norm2d_impl(
    Tensor qx,
    std::optional<Tensor> mb_weight,
    std::optional<Tensor> mb_bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");
  const auto& weight = *mb_weight;
  const auto& bias = *mb_bias;

  if (qx.numel() == 0) {
    auto out = qx.clone();
    return out;
  }
  int64_t ndim = qx.dim();
  TORCH_CHECK(ndim == 4, "Expecting the input tensor of rank 4.");
  const int64_t N = qx.size(0);
  const int64_t C = qx.size(1);
  const int64_t H = qx.size(2);
  const int64_t W = qx.size(3);

  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");

  const float* weight_data = weight.template const_data_ptr<float>();
  const float* bias_data = bias.template const_data_ptr<float>();

  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");

  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  float* alpha_data = alpha.mutable_data_ptr<float>();
  float* beta_data = beta.mutable_data_ptr<float>();

  const float* mean_data = mean.template const_data_ptr<float>();
  const float* var_data = var.template const_data_ptr<float>();

  auto oSizes = qx.sizes();
  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
  Tensor qy = at::_empty_affine_quantized(
      oSizes,
      at::device(kCPU)
          .dtype(qx_nhwc.scalar_type())
          .memory_format(MemoryFormat::ChannelsLast),
      output_scale,
      output_zero_point,
      std::nullopt);

  compute_fused_params(
      C,
      weight_data,
      bias_data,
      mean_data,
      var_data,
      eps,
      qx.q_scale(),
      output_scale,
      alpha_data,
      beta_data);
  if (ReluFused) {
    qbatch_norm_relu_stub(
        qx.device().type(),
        N,
        C,
        H * W,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  } else {
    qbatch_norm_stub(
        qx.device().type(),
        N,
        C,
        H * W,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  }
  return qy;
}

template <bool ReluFused>
Tensor q_batch_norm3d_impl(
    Tensor qx,
    std::optional<Tensor> mb_weight,
    std::optional<Tensor> mb_bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");

  const auto& weight = *mb_weight;
  const auto& bias = *mb_bias;

  if (qx.numel() == 0) {
    auto out = qx.clone();
    return out;
  }
  int64_t ndim = qx.dim();
  TORCH_CHECK(ndim == 5, "Expecting the input tensor of rank 5.");
  const int64_t N = qx.size(0);
  const int64_t C = qx.size(1);
  const int64_t D = qx.size(2);
  const int64_t H = qx.size(3);
  const int64_t W = qx.size(4);

  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");

  const float* weight_data = weight.template const_data_ptr<float>();
  const float* bias_data = bias.template const_data_ptr<float>();

  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");

  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  float* alpha_data = alpha.mutable_data_ptr<float>();
  float* beta_data = beta.mutable_data_ptr<float>();

  const float* mean_data = mean.template const_data_ptr<float>();
  const float* var_data = var.template const_data_ptr<float>();

  auto oSizes = qx.sizes();
  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast3d);
  Tensor qy = at::_empty_affine_quantized(
      oSizes,
      at::device(kCPU)
          .dtype(qx_nhwc.scalar_type())
          .memory_format(MemoryFormat::ChannelsLast3d),
      output_scale,
      output_zero_point,
      std::nullopt);

  compute_fused_params(
      C,
      weight_data,
      bias_data,
      mean_data,
      var_data,
      eps,
      qx.q_scale(),
      output_scale,
      alpha_data,
      beta_data);

  if (ReluFused) {
    qbatch_norm_relu_stub(
        qx.device().type(),
        N,
        C,
        D * H * W,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  } else {
    qbatch_norm_stub(
        qx.device().type(),
        N,
        C,
        D * H * W,
        qx.q_zero_point(),
        output_zero_point,
        qx_nhwc,
        alpha,
        beta,
        qy);
  }
  return qy;
}

template <bool ReluFused>
Tensor q_batch_norm_impl(
    Tensor qx,
    std::optional<Tensor> mb_weight,
    std::optional<Tensor> mb_bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  Tensor qy;
  int64_t dim = qx.dim();
  if (dim == 2 || dim == 3) {
    qy = q_batch_norm1d_impl<ReluFused>(
        qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
  } else if (dim == 4) {
    qy = q_batch_norm2d_impl<ReluFused>(
        qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
  } else if (dim == 5) {
    qy = q_batch_norm3d_impl<ReluFused>(
        qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
  } else {
    TORCH_CHECK(false, "quantized::batch_norm only supports 2d, 3d, 4d, or 5d inputs.");
  }
  return qy;
}

} // namespace

Tensor quantized_batch_norm(
    const Tensor& qx,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    const Tensor& mean,
    const Tensor& var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});

  Tensor qy;
  // TODO: this should arguably support 3d as well
  qy = q_batch_norm2d_impl<false>(
      qx,
      weight.defined() ? std::make_optional(weight) : std::nullopt,
      bias.defined() ? std::make_optional(bias) : std::nullopt,
      mean, var, eps, output_scale, output_zero_point);
  return qy;
}

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm"), TORCH_FN(q_batch_norm_impl<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm_relu"), TORCH_FN(q_batch_norm_impl<true>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm1d"), TORCH_FN(q_batch_norm1d_impl<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm1d_relu"), TORCH_FN(q_batch_norm1d_impl<true>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm2d"), TORCH_FN(q_batch_norm2d_impl<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm2d_relu"), TORCH_FN(q_batch_norm2d_impl<true>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm3d"), TORCH_FN(q_batch_norm3d_impl<false>));
  m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm3d_relu"), TORCH_FN(q_batch_norm3d_impl<true>));
}

} // namespace native
} // namespace at