xref: /aosp_15_r20/external/pytorch/test/quantization/ao_migration/test_quantization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 # Owner(s): ["oncall: quantization"]
2 
3 from .common import AOMigrationTestCase
4 
5 
class TestAOMigrationQuantization(AOMigrationTestCase):
    r"""Import-compatibility checks for moving `torch/quantization`
    to `torch/ao/quantization`.

    Each test names one submodule and lists the attributes it exposes. The
    test then delegates to the helpers inherited from ``AOMigrationTestCase``:
    ``_test_function_import`` for callables/classes/objects and
    ``_test_dict_import`` for module-level mapping constants.

    The actual comparison logic lives in those helpers, not here. Presumably
    they verify each name resolves to the same object under the old and new
    package; confirm against ``.common``.
    """

    def test_function_import_quantize(self):
        """Names exposed by the ``quantize`` submodule."""
        expected_names = [
            "_convert",
            "_observer_forward_hook",
            "_propagate_qconfig_helper",
            "_remove_activation_post_process",
            "_remove_qconfig",
            "_add_observer_",
            "add_quant_dequant",
            "convert",
            "_get_observer_dict",
            "_get_unique_devices_",
            "_is_activation_post_process",
            "prepare",
            "prepare_qat",
            "propagate_qconfig_",
            "quantize",
            "quantize_dynamic",
            "quantize_qat",
            "_register_activation_post_process_hook",
            "swap_module",
        ]
        self._test_function_import("quantize", expected_names)

    def test_function_import_stubs(self):
        """Stub and wrapper classes exposed by the ``stubs`` submodule."""
        self._test_function_import(
            "stubs",
            ["QuantStub", "DeQuantStub", "QuantWrapper"],
        )

    def test_function_import_quantize_jit(self):
        """Names exposed by the ``quantize_jit`` submodule."""
        expected_names = [
            "_check_is_script_module",
            "_check_forward_method",
            "script_qconfig",
            "script_qconfig_dict",
            "fuse_conv_bn_jit",
            "_prepare_jit",
            "prepare_jit",
            "prepare_dynamic_jit",
            "_convert_jit",
            "convert_jit",
            "convert_dynamic_jit",
            "_quantize_jit",
            "quantize_jit",
            "quantize_dynamic_jit",
        ]
        self._test_function_import("quantize_jit", expected_names)

    def test_function_import_fake_quantize(self):
        """Classes, default instances and toggles from ``fake_quantize``."""
        expected_names = [
            "_is_per_channel",
            "_is_per_tensor",
            "_is_symmetric_quant",
            "FakeQuantizeBase",
            "FakeQuantize",
            "FixedQParamsFakeQuantize",
            "FusedMovingAvgObsFakeQuantize",
            "default_fake_quant",
            "default_weight_fake_quant",
            "default_fixed_qparams_range_neg1to1_fake_quant",
            "default_fixed_qparams_range_0to1_fake_quant",
            "default_per_channel_weight_fake_quant",
            "default_histogram_fake_quant",
            "default_fused_act_fake_quant",
            "default_fused_wt_fake_quant",
            "default_fused_per_channel_wt_fake_quant",
            "_is_fake_quant_script_module",
            "disable_fake_quant",
            "enable_fake_quant",
            "disable_observer",
            "enable_observer",
        ]
        self._test_function_import("fake_quantize", expected_names)

    def test_function_import_fuse_modules(self):
        """Names exposed by the ``fuse_modules`` submodule."""
        expected_names = [
            "_fuse_modules",
            "_get_module",
            "_set_module",
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_known_modules",
            "fuse_modules",
            "get_fuser_method",
        ]
        self._test_function_import("fuse_modules", expected_names)

    def test_function_import_quant_type(self):
        """Names exposed by the ``quant_type`` submodule."""
        self._test_function_import(
            "quant_type",
            ["QuantType", "_get_quant_type_to_str"],
        )

    def test_function_import_observer(self):
        """Observer classes, helpers and default instances from ``observer``."""
        expected_names = [
            "_PartialWrapper",
            "_with_args",
            "_with_callable_args",
            "ABC",
            "ObserverBase",
            "_ObserverBase",
            "MinMaxObserver",
            "MovingAverageMinMaxObserver",
            "PerChannelMinMaxObserver",
            "MovingAveragePerChannelMinMaxObserver",
            "HistogramObserver",
            "PlaceholderObserver",
            "RecordingObserver",
            "NoopObserver",
            "_is_activation_post_process",
            "_is_per_channel_script_obs_instance",
            "get_observer_state_dict",
            "load_observer_state_dict",
            "default_observer",
            "default_placeholder_observer",
            "default_debug_observer",
            "default_weight_observer",
            "default_histogram_observer",
            "default_per_channel_weight_observer",
            "default_dynamic_quant_observer",
            "default_float_qparams_observer",
        ]
        self._test_function_import("observer", expected_names)

    def test_function_import_qconfig(self):
        """QConfig types, default configs and helpers from ``qconfig``."""
        expected_names = [
            "QConfig",
            "default_qconfig",
            "default_debug_qconfig",
            "default_per_channel_qconfig",
            "QConfigDynamic",
            "default_dynamic_qconfig",
            "float16_dynamic_qconfig",
            "float16_static_qconfig",
            "per_channel_dynamic_qconfig",
            "float_qparams_weight_only_qconfig",
            "default_qat_qconfig",
            "default_weight_only_qconfig",
            "default_activation_only_qconfig",
            "default_qat_qconfig_v2",
            "get_default_qconfig",
            "get_default_qat_qconfig",
            "_assert_valid_qconfig",
            "QConfigAny",
            "_add_module_to_qconfig_obs_ctr",
            "qconfig_equals",
        ]
        self._test_function_import("qconfig", expected_names)

    def test_function_import_quantization_mappings(self):
        """Mapping getters and mapping constants from ``quantization_mappings``."""
        module = "quantization_mappings"
        functions = [
            "no_observer_set",
            "get_default_static_quant_module_mappings",
            "get_static_quant_module_class",
            "get_dynamic_quant_module_class",
            "get_default_qat_module_mappings",
            "get_default_dynamic_quant_module_mappings",
            "get_default_qconfig_propagation_list",
            "get_default_compare_output_module_list",
            "get_default_float_to_quantized_operator_mappings",
            "get_quantized_operator",
            "_get_special_act_post_process",
            "_has_special_act_post_process",
        ]
        # NOTE(review): "_INCLUDE_QCONFIG_PROPAGATE_LIST" was deliberately left
        # out of this list (it appeared commented out). The reason is not
        # visible here; confirm before adding it back.
        dicts = [
            "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_QAT_MODULE_MAPPINGS",
            "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
            "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
        ]
        # Callables are checked first, then the mapping constants.
        self._test_function_import(module, functions)
        self._test_dict_import(module, dicts)

    def test_function_import_fuser_method_mappings(self):
        """Fuser functions and the op-list mapping from ``fuser_method_mappings``."""
        module = "fuser_method_mappings"
        functions = [
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_linear_bn",
            "get_fuser_method",
        ]
        self._test_function_import(module, functions)
        self._test_dict_import(module, ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"])

    def test_function_import_utils(self):
        """Dtype/qparam helper functions exposed by the ``utils`` submodule."""
        expected_names = [
            "activation_dtype",
            "activation_is_int8_quantized",
            "activation_is_statically_quantized",
            "calculate_qmin_qmax",
            "check_min_max_valid",
            "get_combined_dict",
            "get_qconfig_dtypes",
            "get_qparam_dict",
            "get_quant_type",
            "get_swapped_custom_module_class",
            "getattr_from_fqn",
            "is_per_channel",
            "is_per_tensor",
            "weight_dtype",
            "weight_is_quantized",
            "weight_is_statically_quantized",
        ]
        self._test_function_import("utils", expected_names)
221         self._test_function_import("utils", function_list)
222