# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
"""Tests for pyarmnn's tensor helpers: make_input_tensors, make_output_tensors and workload_tensors_to_ndarray."""
import os

import pytest
import pyarmnn as ann
import numpy as np

# Subgraph index queried on the parsed mock model.
_GRAPH_ID = 0
# Model file looked up inside the folder supplied by the ``shared_data_folder`` fixture.
_MOCK_MODEL_FILE = 'mock_model.tflite'


def _create_parser(shared_data_folder):
    """Return an ITfLiteParser that has already parsed the shared mock model."""
    parser = ann.ITfLiteParser()
    parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, _MOCK_MODEL_FILE))
    return parser


@pytest.fixture(scope="function")
def get_tensor_info_input(shared_data_folder):
    """
    Sample input tensor information.

    Yields a one-element list holding the binding info of the mock model's
    'input_1' tensor.
    """
    parser = _create_parser(shared_data_folder)
    # Yield (rather than return) keeps ``parser`` referenced until the test
    # using this fixture has finished, matching the original behaviour.
    yield [parser.GetNetworkInputBindingInfo(_GRAPH_ID, 'input_1')]


@pytest.fixture(scope="function")
def get_tensor_info_output(shared_data_folder):
    """
    Sample output tensor information.

    Yields a list with the binding info of every output tensor of the mock
    model's subgraph, in the order reported by the parser.
    """
    parser = _create_parser(shared_data_folder)
    output_names = parser.GetSubgraphOutputTensorNames(_GRAPH_ID)
    # Yield keeps ``parser`` alive for the duration of the test (see above).
    yield [parser.GetNetworkOutputBindingInfo(_GRAPH_ID, name) for name in output_names]


def test_make_input_tensors(get_tensor_info_input):
    """make_input_tensors wraps each uint8 input array in a ConstTensor matching its binding info."""
    input_tensor_info = get_tensor_info_input
    input_data = [
        np.random.randint(0, 255, size=(1, tensor_info.GetNumElements())).astype(np.uint8)
        for _, tensor_info in input_tensor_info
    ]

    input_tensors = ann.make_input_tensors(input_tensor_info, input_data)
    assert len(input_tensors) == 1

    for (_, tensor), (_, tensor_info) in zip(input_tensors, input_tensor_info):
        # ConstTensor is created through a factory function, so its type cannot
        # be compared directly; compare the class name instead.
        assert type(tensor).__name__ == 'ConstTensor'
        assert str(tensor.GetInfo()) == str(tensor_info)


def test_make_output_tensors(get_tensor_info_output):
    """make_output_tensors creates one plain Tensor per output binding with matching info."""
    output_binding_info = get_tensor_info_output

    output_tensors = ann.make_output_tensors(output_binding_info)
    assert len(output_tensors) == 1

    for (_, tensor), (_, tensor_info) in zip(output_tensors, output_binding_info):
        # Exact-type check on purpose: the output must be a Tensor, not a subclass.
        assert type(tensor) is ann.Tensor
        assert str(tensor.GetInfo()) == str(tensor_info)


def test_workload_tensors_to_ndarray(get_tensor_info_output):
    """Check shape and size of output from workload_tensors_to_ndarray matches expected."""
    output_binding_info = get_tensor_info_output
    output_tensors = ann.make_output_tensors(output_binding_info)

    data = ann.workload_tensors_to_ndarray(output_tensors)

    # One array per output tensor; zip alone would silently hide a mismatch.
    assert len(data) == len(output_tensors)
    for array, (_, tensor) in zip(data, output_tensors):
        assert array.shape == tuple(tensor.GetShape())
        assert array.size == tensor.GetNumElements()


def test_make_input_tensors_fp16(get_tensor_info_input):
    """Check that make_input_tensors builds float16 ConstTensors once the binding info is set to Float16."""
    input_tensor_info = get_tensor_info_input
    input_data = []

    for _, tensor_info in input_tensor_info:
        input_data.append(np.random.randint(0, 255, size=(1, tensor_info.GetNumElements())).astype(np.float16))
        # Mutates the binding info in place so it describes float16 data.
        tensor_info.SetDataType(ann.DataType_Float16)

    input_tensors = ann.make_input_tensors(input_tensor_info, input_data)
    assert len(input_tensors) == 1

    # The mock model's input holds 28*28*1 elements.
    expected_elements = 28 * 28 * 1
    fp16_bytes = np.dtype(np.float16).itemsize  # two bytes per element

    for (_, tensor), (_, tensor_info) in zip(input_tensors, input_tensor_info):
        # ConstTensor is created through a factory function, so its type cannot
        # be compared directly; compare the class name instead.
        assert type(tensor).__name__ == 'ConstTensor'
        assert str(tensor.GetInfo()) == str(tensor_info)
        assert tensor.GetDataType() == ann.DataType_Float16
        assert tensor.GetNumElements() == expected_elements
        assert tensor.GetNumBytes() == expected_elements * fp16_bytes