1 #pragma once 2 3 #include <torch/csrc/jit/api/module.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 #include <torch/csrc/jit/passes/quantization/quantization_type.h> 6 7 namespace torch { 8 namespace jit { 9 10 /** \brief Backend specific pass to fuse dequantize - op - quantize calls 11 * as quantized_op calls. 12 * 13 * Right now this is a fusion for fbgemm backend and only works for quantized 14 * conv op, we'll extend to more ops and more backends in the future. 15 * 16 * Currently supported fusion: 17 * q(conv2d(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_conv2d(prepack(to_nhwc(a)), 18 * prepack(to_nhwc(w)), 19 * prepack(to_nhwc(b)))) 20 * 21 * q(linear(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_linear(prepack(to_nhwc(a)), 22 * prepack(to_nhwc(w)), 23 * prepack(to_nhwc(b)))) 24 * 25 * \param graph the graph we want to apply fusion 26 */ 27 TORCH_API void QuantFusion( 28 std::shared_ptr<Graph>& graph, 29 QuantType quant_type = QuantType::STATIC); 30 31 /** \brief Insert prepack and unpack function in graph 32 * We want add pack/unpack functions for quantized weight because later we want 33 * to fold the packed weight as an attribute of the module, in order to reduce 34 * the cost of packing the weight on the fly in quantized models. 35 * 36 * Each quantized op has it's corresponding prepack/unpack function, 37 * right now, we only need to do prepack/unpack for quantized::linear 38 * and quantized::conv2d. 39 */ 40 TORCH_API void InsertPrepackUnpack(std::shared_ptr<Graph>& graph); 41 42 /** \brief Insert pack and unpack function in all graphs 43 * of module 44 * 45 * Go through graphs of all the methods of all child modules 46 * and call InsertPrepackUnpack on the graph. 47 */ 48 TORCH_API void InsertPrepackUnpack(Module& module); 49 50 TORCH_API script::Module Finalize( 51 script::Module& module, 52 QuantType quant_type = QuantType::STATIC, 53 const std::vector<std::string>& preserved_attrs = 54 std::vector<std::string>()); 55 56 TORCH_API void FoldQuantizedPrepackingOps(Module& module); 57 58 TORCH_API Module FinalizeOnDevicePTQ( 59 Module& module, 60 QuantType quant_type, 61 const std::string& method_name); 62 } // namespace jit 63 } // namespace torch 64