xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_quantize_and_dequantize.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2020 Arm Ltd. All rights reserved.
2# SPDX-License-Identifier: MIT
3import pytest
4import numpy as np
5
6import pyarmnn as ann
7
# Import the generated module directly so we can test the Dequantize_* and
# Quantize_* functions, which are not exposed through the public API.
10import pyarmnn._generated.pyarmnn as gen_ann
11
12
@pytest.mark.parametrize('method', [
    'Quantize_int8_t',
    'Quantize_uint8_t',
    'Quantize_int16_t',
    'Quantize_int32_t',
    'Dequantize_int8_t',
    'Dequantize_uint8_t',
    'Dequantize_int16_t',
    'Dequantize_int32_t',
])
def test_quantize_exists(method):
    """Check that the generated bindings expose ``method`` as a callable.

    The lookup is done on ``pyarmnn._generated.pyarmnn`` because these
    typed helpers are not part of the public ``pyarmnn`` namespace.
    """
    exported_names = dir(gen_ann)
    # Short-circuits: getattr() only runs when the name is actually listed.
    assert method in exported_names and callable(getattr(gen_ann, method))
23
24
@pytest.mark.parametrize('dt, lower, upper', [('uint8', 0, 255),
                                              ('int8', -128, 127),
                                              ('int16', -32768, 32767),
                                              ('int32', -2147483648, 2147483647)])
def test_quantize_uint8_output(dt, lower, upper):
    """Quantize a float to ``dt`` and check the type and range of the result.

    Despite its name, this test runs for every dtype in the parametrization,
    not only uint8. ``lower``/``upper`` are the inclusive bounds of ``dt``
    (named so they do not shadow the ``min``/``max`` builtins).
    """
    result = ann.quantize(3.3274056911468506, 0.02620004490017891, 128, dt)

    # Exact type check on purpose: isinstance() would also accept bool, and
    # NumPy integer scalars are not ``int`` instances, so both are rejected.
    assert type(result) is int
    # Separate assert so a failure pinpoints the range check, not the type.
    assert lower <= result <= upper
32
33
@pytest.mark.parametrize('dt', ['uint8', 'int8', 'int16', 'int32'])
def test_dequantize_uint8_output(dt):
    """Dequantizing an in-range integer must yield a plain Python ``float``.

    Runs for each supported source dtype, not only uint8.
    """
    quantized_value, scale, offset = 3, 0.02620004490017891, 128
    dequantized = ann.dequantize(quantized_value, scale, offset, dt)
    assert type(dequantized) is float
41
42
def test_quantize_unsupported_dtype():
    """``quantize`` must raise ValueError for a target dtype it cannot map."""
    with pytest.raises(ValueError) as excinfo:
        ann.quantize(3.3274056911468506, 0.02620004490017891, 128, 'uint16')

    message = str(excinfo.value)
    assert 'Unexpected target datatype uint16 given.' in message
48
49
def test_dequantize_unsupported_dtype():
    """``dequantize`` must raise ValueError for a source dtype it cannot map."""
    with pytest.raises(ValueError) as excinfo:
        ann.dequantize(3, 0.02620004490017891, 128, 'uint16')

    message = str(excinfo.value)
    assert 'Unexpected value datatype uint16 given.' in message
55
56
def test_dequantize_value_range():
    """``dequantize`` must reject a value outside the range of its dtype.

    -1 cannot be represented as uint8, so a ValueError is expected.
    """
    with pytest.raises(ValueError) as excinfo:
        ann.dequantize(-1, 0.02620004490017891, 128, 'uint8')

    message = str(excinfo.value)
    assert 'Value is not within range of the given datatype uint8' in message
62
63
@pytest.mark.parametrize('dt, data', [
    # NumPy scalar whose dtype matches the declared source dtype.
    ('uint8', np.uint8(255)),
    ('int8', np.int8(127)),
    ('int16', np.int16(32767)),
    ('int32', np.int32(2147483647)),

    # NumPy scalars of a different dtype whose value still fits the declared
    # source dtype.
    ('uint8', np.int8(127)),
    ('uint8', np.int16(255)),
    ('uint8', np.int32(255)),

    ('int8', np.uint8(127)),
    ('int8', np.int16(127)),
    ('int8', np.int32(127)),

    ('int16', np.int8(127)),
    ('int16', np.uint8(255)),
    ('int16', np.int32(32767)),

    ('int32', np.uint8(255)),
    # Was a duplicated ('int16', np.int8(127)); the int32 group was missing
    # its int8 input case.
    ('int32', np.int8(127)),
    ('int32', np.int16(32767)),
])
def test_dequantize_numpy_dt(dt, data):
    """``dequantize`` must accept NumPy integer scalars as input values.

    With scale 1 and offset 0 the dequantized value should equal the input,
    and the result must be a plain Python ``float``.
    """
    result = ann.dequantize(data, 1, 0, dt)

    assert type(result) is float

    # Compared through float32 because e.g. the int32 maximum is not exactly
    # representable there; presumably this mirrors the precision of the
    # native result -- NOTE(review): confirm against armnn::Dequantize.
    assert np.float32(data) == result
92