# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import pytest
import numpy as np

import pyarmnn as ann

# Import the generated module so we can test for the Dequantize_* and
# Quantize_* functions, which are not available in the public API.
import pyarmnn._generated.pyarmnn as gen_ann


@pytest.mark.parametrize('method', ['Quantize_int8_t',
                                    'Quantize_uint8_t',
                                    'Quantize_int16_t',
                                    'Quantize_int32_t',
                                    'Dequantize_int8_t',
                                    'Dequantize_uint8_t',
                                    'Dequantize_int16_t',
                                    'Dequantize_int32_t'])
def test_quantize_exists(method):
    """Check the generated module exposes a callable for each typed (de)quantize variant."""
    # Separate asserts so a failure shows whether the symbol is missing
    # or merely present but not callable.
    assert method in dir(gen_ann)
    assert callable(getattr(gen_ann, method))


@pytest.mark.parametrize('dt, min, max', [('uint8', 0, 255),
                                          ('int8', -128, 127),
                                          ('int16', -32768, 32767),
                                          ('int32', -2147483648, 2147483647)])
def test_quantize_uint8_output(dt, min, max):
    """Check ann.quantize returns a plain int within the representable range of *dt*.

    Despite its name, this test is parametrized over every supported target
    dtype, not only uint8.
    """
    result = ann.quantize(3.3274056911468506, 0.02620004490017891, 128, dt)
    # `type(...) is int` is stricter than isinstance: it rejects int subclasses such as bool.
    assert type(result) is int and min <= result <= max


@pytest.mark.parametrize('dt', ['uint8',
                                'int8',
                                'int16',
                                'int32'])
def test_dequantize_uint8_output(dt):
    """Check ann.dequantize returns a plain float.

    Despite its name, this test is parametrized over every supported value
    dtype, not only uint8.
    """
    result = ann.dequantize(3, 0.02620004490017891, 128, dt)
    assert type(result) is float


def test_quantize_unsupported_dtype():
    """Check ann.quantize rejects an unsupported target dtype with ValueError."""
    with pytest.raises(ValueError) as err:
        ann.quantize(3.3274056911468506, 0.02620004490017891, 128, 'uint16')

    assert 'Unexpected target datatype uint16 given.' in str(err.value)


def test_dequantize_unsupported_dtype():
    """Check ann.dequantize rejects an unsupported value dtype with ValueError."""
    with pytest.raises(ValueError) as err:
        ann.dequantize(3, 0.02620004490017891, 128, 'uint16')

    assert 'Unexpected value datatype uint16 given.' in str(err.value)


def test_dequantize_value_range():
    """Check ann.dequantize rejects a value outside the range of the given dtype."""
    with pytest.raises(ValueError) as err:
        # -1 cannot be represented as uint8.
        ann.dequantize(-1, 0.02620004490017891, 128, 'uint8')

    assert 'Value is not within range of the given datatype uint8' in str(err.value)


# Every (target dtype, numpy input dtype) pairing of the four supported types
# appears exactly once: 4 same-type cases plus 3 cross-type cases per target.
@pytest.mark.parametrize('dt, data', [('uint8', np.uint8(255)),
                                      ('int8', np.int8(127)),
                                      ('int16', np.int16(32767)),
                                      ('int32', np.int32(2147483647)),

                                      ('uint8', np.int8(127)),
                                      ('uint8', np.int16(255)),
                                      ('uint8', np.int32(255)),

                                      ('int8', np.uint8(127)),
                                      ('int8', np.int16(127)),
                                      ('int8', np.int32(127)),

                                      ('int16', np.int8(127)),
                                      ('int16', np.uint8(255)),
                                      ('int16', np.int32(32767)),

                                      ('int32', np.uint8(255)),
                                      ('int32', np.int8(127)),
                                      ('int32', np.int16(32767)),
                                      ])
def test_dequantize_numpy_dt(dt, data):
    """Check ann.dequantize accepts numpy scalar inputs of any supported integer dtype.

    With scale 1 and offset 0 the expected result equals the input value.
    """
    result = ann.dequantize(data, 1, 0, dt)

    assert type(result) is float

    # NOTE(review): the expected value is built with np.float32, presumably
    # because the native dequantize works in single precision (e.g.
    # 2147483647 is not exactly representable there) — confirm.
    assert np.float32(data) == result