xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qconv_unpack_impl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <tuple>
2 #include <vector>
3 
4 #include <ATen/ATen.h>
5 #include <torch/library.h>
6 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
7 #include <ATen/native/quantized/cpu/QnnpackUtils.h>
8 #include <ATen/native/quantized/cpu/OnednnUtils.h>
9 #include <ATen/native/quantized/cpu/QuantUtils.h>
10 #include <ATen/native/quantized/PackedParams.h>
11 
#ifdef USE_FBGEMM
// Reconstructs a dequantizable weight tensor (plus the optional bias) from the
// FBGEMM-packed convolution weights held in `w`. The result is a quantized
// int8 tensor with the same scheme (per-tensor or per-channel affine) that the
// weights were packed with; the bias is returned as-is.
template <int kSpatialDim>
std::tuple<at::Tensor, std::optional<at::Tensor>> PackedConvWeight<
    kSpatialDim>::unpack() {
  // Non-owning pointer into the packed-weight object; `w` keeps ownership.
  auto* packed_weights_p = w.get();
  // output channels
  const int output_channels = packed_weights_p->outputChannels();
  const int input_channels = packed_weights_p->inputChannels();
  const int groups = packed_weights_p->groups();

  // T (kernel depth) — a 2d conv is treated as 3d with depth 1.
  const int kernel_d = kSpatialDim == 2 ? 1 : kernel[0];
  // R (kernel height)
  const int kernel_h = kernel[kSpatialDim - 2];
  // S (kernel width)
  const int kernel_w = kernel[kSpatialDim - 1];

  // Input channels per group (C/G).
  const int C_per_G = input_channels / groups;

  // Tensor for unpacked weights
  // Unpacked format would be physical KRS(C/G) but logical KCRS (channels
  // first) because that's how FBGEMM stores the weights.
  // ChannelsLast3d is not available now.
  // TODO: Unify 2d and 3d when ChannelsLast3d is ready.
  at::Tensor unpacked_weights;
  if (q_scheme == c10::kPerTensorAffine) {
    // Single scale/zero-point pair: w_scale[0] / w_zp[0].
    unpacked_weights = kSpatialDim == 2
        ? at::_empty_affine_quantized(
              {output_channels, C_per_G, kernel_h, kernel_w},
              device(c10::kCPU)
                  .dtype(c10::kQInt8)
                  .memory_format(c10::MemoryFormat::ChannelsLast),
              w_scale[0],
              w_zp[0],
              std::nullopt)
        : at::native::fbgemm_utils::
              MakeEmptyAffineQuantizedChannelsLast3dTensor(
                  output_channels,
                  C_per_G,
                  kernel_d,
                  kernel_h,
                  kernel_w,
                  device(c10::kCPU).dtype(c10::kQInt8),
                  w_scale[0],
                  w_zp[0]);
  } else if (q_scheme == c10::kPerChannelAffine) {
    TORCH_CHECK(
        !transpose(),
        "Per Channel Quantization is currently disabled for transposed conv");
    // Non-owning views over the stored per-channel quantization params; the
    // toType() calls below copy them, so the views need not outlive this call.
    auto scales = at::from_blob(
        w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat));
    auto zero_points = at::from_blob(
        w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kInt));
    unpacked_weights = kSpatialDim == 2
        ? at::_empty_per_channel_affine_quantized(
              {output_channels, C_per_G, kernel_h, kernel_w},
              scales.toType(c10::kDouble),
              zero_points.toType(c10::kLong),
              0, /* The output channel axis is 0 */
              device(c10::kCPU).dtype(c10::kQInt8),
              c10::MemoryFormat::ChannelsLast)
        : at::native::fbgemm_utils::
              MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor(
                  output_channels,
                  C_per_G,
                  kernel_d,
                  kernel_h,
                  kernel_w,
                  device(c10::kCPU).dtype(c10::kQInt8),
                  scales.toType(c10::kDouble),
                  zero_points.toType(c10::kLong));
  } else {
    TORCH_CHECK(false, "Unsupported qscheme: ", toString(q_scheme));
  }
  // FBGEMM writes raw int8 data straight into the freshly allocated tensor.
  int8_t* unpacked_weights_p =
      reinterpret_cast<int8_t*>(unpacked_weights.data_ptr<c10::qint8>());
  packed_weights_p->unpack(unpacked_weights_p);
  if(transpose()){
    // Transposed convs store weights in a layout that must be converted back
    // (IC, OC/G, ...) before being handed to the caller.
    unpacked_weights =
        at::native::fbgemm_utils::TransposeConvTensorUnpackConversion<
            kSpatialDim>(unpacked_weights, groups);
  }
  return std::tuple<at::Tensor, std::optional<at::Tensor>>(
      unpacked_weights, bias);
}

// Explicit instantiations for the only supported spatial dims (2d and 3d).
template std::tuple<at::Tensor, std::optional<at::Tensor>> PackedConvWeight<
    2>::unpack();
template std::tuple<at::Tensor, std::optional<at::Tensor>> PackedConvWeight<
    3>::unpack();
#endif // USE_FBGEMM
102 
#ifdef USE_PYTORCH_QNNPACK
// QNNPACK does not reverse its packing; instead it hands back the original
// weight tensor it retained at pack time, together with the optional bias.
// Fails if the original weights were released after packing.
template <int kSpatialDim>
std::tuple<at::Tensor, std::optional<at::Tensor>> PackedConvWeightsQnnp<
    kSpatialDim>::unpack() {
  // Only the 2d path is implemented on the QNNPACK backend.
  TORCH_CHECK(
      kSpatialDim == 2,
      "QNNPACK only supports conv2d_unpack right now.");
  // Unpacking is only possible when the original weights were kept around.
  TORCH_CHECK(
      orig_weight.defined(),
      "Cannot unpack weights. Call at::globalContext()::setReleaseOriginalWeights(false) before packing or loading to enable unpacking.");
  return {orig_weight, bias};
}

// Explicit instantiations for 2d and 3d.
template std::tuple<at::Tensor, std::optional<at::Tensor>>
PackedConvWeightsQnnp<2>::unpack();
template std::tuple<at::Tensor, std::optional<at::Tensor>>
PackedConvWeightsQnnp<3>::unpack();
#endif // USE_PYTORCH_QNNPACK
123 
124 #if AT_MKLDNN_ENABLED()
125 template <int kSpatialDim>
126 std::tuple<at::Tensor, std::optional<at::Tensor>> PackedConvWeightsOnednn<
unpack()127     kSpatialDim>::unpack() {
128   return std::tuple<at::Tensor, std::optional<at::Tensor>>(
129       orig_weight_.clone(), orig_bias_);
130 }
131 
132 template std::tuple<at::Tensor, std::optional<at::Tensor>> PackedConvWeightsOnednn<
133     2>::unpack();
134 template std::tuple<at::Tensor, std::optional<at::Tensor>> PackedConvWeightsOnednn<
135     3>::unpack();
136 #endif // #if AT_MKLDNN_ENABLED()
137