xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/quant_type.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import enum
2
3
# The only public name this module exports is the QuantType enum.
__all__ = ["QuantType"]
7
8
# Quantization type (dynamic, static, QAT, weight-only).
# Should match the C++ enum in quantization_type.h
@enum.unique  # values mirror a C++ enum, so two members must never share one
class QuantType(enum.IntEnum):
    """Kind of quantization being applied.

    Member values are spelled out explicitly because they must stay in sync
    with the C++ enum in ``quantization_type.h``; do not renumber or reorder
    them.
    """

    DYNAMIC = 0
    STATIC = 1
    QAT = 2  # quantization-aware training
    WEIGHT_ONLY = 3
16
17
# Lowercase string name for every QuantType member. Entries are listed as
# (member, name) pairs; the resulting dict keeps this insertion order.
_quant_type_to_str = dict(
    (
        (QuantType.STATIC, "static"),
        (QuantType.DYNAMIC, "dynamic"),
        (QuantType.QAT, "qat"),
        (QuantType.WEIGHT_ONLY, "weight_only"),
    )
)
24
25
def _get_quant_type_to_str(quant_type: QuantType) -> str:
    """Return the lowercase string name of ``quant_type``.

    Args:
        quant_type: The quantization type to look up.

    Returns:
        The name stored for it in ``_quant_type_to_str`` (e.g. ``"static"``).

    Raises:
        KeyError: If ``quant_type`` has no entry in ``_quant_type_to_str``.
    """
    return _quant_type_to_str[quant_type]
29
30
def _quant_type_from_str(name: str) -> QuantType:
    """Return the QuantType whose string name equals ``name``.

    This is the reverse lookup of ``_quant_type_to_str``: it searches that
    mapping's values and returns the first member whose name matches.

    Raises:
        ValueError: If no QuantType is registered under ``name``.
    """
    found = next(
        (member for member, label in _quant_type_to_str.items() if name == label),
        None,
    )
    # Compare against None explicitly: QuantType.DYNAMIC == 0 is falsy.
    if found is None:
        raise ValueError(f"Unknown QuantType name '{name}'")
    return found
36