#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>

namespace torch {
namespace jit {

/** \brief Backend-specific pass to fuse dequantize - op - quantize call
 * patterns into quantized_op calls.
 *
 * Right now this fusion targets the fbgemm backend and only covers the
 * quantized conv and linear ops; we'll extend it to more ops and more
 * backends in the future.
 *
 * Currently supported fusions:
 * q(conv2d(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_conv2d(prepack(to_nhwc(a)),
 *                                                          prepack(to_nhwc(w)),
 *                                                          prepack(to_nhwc(b))))
 *
 * q(linear(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_linear(prepack(to_nhwc(a)),
 *                                                          prepack(to_nhwc(w)),
 *                                                          prepack(to_nhwc(b))))
 *
 * \param graph the graph to apply the fusion to
 * \param quant_type the type of quantization (static or dynamic)
 */
TORCH_API void QuantFusion(
    std::shared_ptr<Graph>& graph,
    QuantType quant_type = QuantType::STATIC);

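// Example usage (a minimal sketch; the model path and method name are
// hypothetical, and the graph is assumed to already contain explicit
// quantize/dequantize calls from the earlier insert-quant-dequant pass):
//
//   Module module = load("quantized_model.pt");
//   auto graph = module.get_method("forward").graph();
//   QuantFusion(graph, QuantType::STATIC);
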
/** \brief Insert prepack and unpack functions in the graph.
 *
 *  We want to add pack/unpack functions for the quantized weight because we
 *  later fold the packed weight in as an attribute of the module, in order
 *  to reduce the cost of packing the weight on the fly in quantized models.
 *
 *  Each quantized op has its corresponding prepack/unpack function; right
 *  now we only need to do prepack/unpack for quantized::linear and
 *  quantized::conv2d.
 */
TORCH_API void InsertPrepackUnpack(std::shared_ptr<Graph>& graph);

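// Example usage (a minimal sketch; assumes `module` holds a scripted model
// whose weights have already been quantized):
//
//   auto graph = module.get_method("forward").graph();
//   InsertPrepackUnpack(graph);
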
/** \brief Insert prepack and unpack functions in all graphs
 *   of the module.
 *
 *   Goes through the graphs of all the methods of all child modules
 *   and calls InsertPrepackUnpack on each graph.
 */
TORCH_API void InsertPrepackUnpack(Module& module);

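/** \brief Finalize the graph-mode quantization of \p module and return the
 *  finalized module.
 *
 *  (A sketch of the pipeline, inferred from the passes declared in this
 *  header: insert prepack/unpack calls, run QuantFusion, then fold the
 *  quantized prepacking ops into module attributes.)
 *
 *  \param module the module to finalize
 *  \param quant_type the type of quantization (static or dynamic)
 *  \param preserved_attrs module attributes that should be preserved
 */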
TORCH_API script::Module Finalize(
    script::Module& module,
    QuantType quant_type = QuantType::STATIC,
    const std::vector<std::string>& preserved_attrs =
        std::vector<std::string>());

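/** \brief Fold the quantized prepacking ops (e.g. the prepack calls inserted
 *  by InsertPrepackUnpack) into module attributes, so the weights are packed
 *  once ahead of time instead of on the fly at every invocation.
 */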
TORCH_API void FoldQuantizedPrepackingOps(Module& module);

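/** \brief Variant of Finalize for the on-device post-training quantization
 *  (PTQ) flow.
 *
 *  (Inferred from the signature: unlike Finalize, this takes no default
 *  arguments and restricts finalization to the method named \p method_name.)
 */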
TORCH_API Module FinalizeOnDevicePTQ(
    Module& module,
    QuantType quant_type,
    const std::string& method_name);
} // namespace jit
} // namespace torch