xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_tensor_conversion.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
3import os
4
5import pytest
6import pyarmnn as ann
7import numpy as np
8
9
@pytest.fixture(scope="function")
def get_tensor_info_input(shared_data_folder):
    """
    Sample input tensor information.

    Loads ``mock_model.tflite`` from ``shared_data_folder`` with the TfLite
    parser and yields a one-element list holding the binding info of the
    input tensor named ``input_1`` in graph 0.
    """
    tflite_parser = ann.ITfLiteParser()
    model_file = os.path.join(shared_data_folder, 'mock_model.tflite')
    tflite_parser.CreateNetworkFromBinaryFile(model_file)

    subgraph = 0
    binding = tflite_parser.GetNetworkInputBindingInfo(subgraph, 'input_1')
    yield [binding]
22
23
@pytest.fixture(scope="function")
def get_tensor_info_output(shared_data_folder):
    """
    Sample output tensor information.

    Loads ``mock_model.tflite`` from ``shared_data_folder`` with the TfLite
    parser and yields a list with the binding info of every output tensor of
    graph 0, in the order returned by ``GetSubgraphOutputTensorNames``.
    """
    parser = ann.ITfLiteParser()
    parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, 'mock_model.tflite'))
    graph_id = 0

    output_names = parser.GetSubgraphOutputTensorNames(graph_id)
    outputs_binding_info = [
        parser.GetNetworkOutputBindingInfo(graph_id, output_name)
        for output_name in output_names
    ]

    yield outputs_binding_info
40
41
def test_make_input_tensors(get_tensor_info_input):
    """Check make_input_tensors wraps uint8 data in ConstTensors whose TensorInfo matches the bindings."""
    input_binding_info = get_tensor_info_input

    # One row of random bytes per input. randint's `high` is exclusive, so 256
    # is needed for the data to cover the full uint8 range 0..255.
    input_data = [
        np.random.randint(0, 256, size=(1, tensor_info.GetNumElements()), dtype=np.uint8)
        for _, tensor_info in input_binding_info
    ]

    input_tensors = ann.make_input_tensors(input_binding_info, input_data)
    assert len(input_tensors) == 1

    for input_tensor, binding_info in zip(input_tensors, input_binding_info):
        # pyarmnn creates ConstTensor through a function, so the type cannot be
        # checked directly; compare the class name instead.
        assert type(input_tensor[1]).__name__ == 'ConstTensor'
        assert str(input_tensor[1].GetInfo()) == str(binding_info[1])
56
57
def test_make_output_tensors(get_tensor_info_output):
    """Check make_output_tensors returns Tensors whose TensorInfo matches the output bindings."""
    output_binding_info = get_tensor_info_output

    output_tensors = ann.make_output_tensors(output_binding_info)
    assert len(output_tensors) == 1

    for output_tensor, binding_info in zip(output_tensors, output_binding_info):
        # Exact-type check; `is` is the idiomatic comparison for class identity.
        assert type(output_tensor[1]) is ann.Tensor
        assert str(output_tensor[1].GetInfo()) == str(binding_info[1])
67
68
def test_workload_tensors_to_ndarray(get_tensor_info_output):
    """Check shape and size of output from workload_tensors_to_ndarray matches expected."""
    output_binding_info = get_tensor_info_output
    output_tensors = ann.make_output_tensors(output_binding_info)

    data = ann.workload_tensors_to_ndarray(output_tensors)
    # Expect one array per output tensor; zip below would otherwise silently
    # ignore any extra or missing arrays.
    assert len(data) == len(output_tensors)

    for array, output_tensor in zip(data, output_tensors):
        tensor = output_tensor[1]
        assert array.shape == tuple(tensor.GetShape())
        assert array.size == tensor.GetNumElements()
79
80
def test_make_input_tensors_fp16(get_tensor_info_input):
    """Check make_input_tensors builds a Float16 ConstTensor from float16 input data."""
    input_binding_info = get_tensor_info_input
    # Element count of the mock model's input, presumably a 28x28x1 image —
    # update if mock_model.tflite changes.
    expected_num_elements = 28 * 28 * 1

    input_data = []
    for _, tensor_info in input_binding_info:
        input_data.append(
            np.random.randint(0, 255, size=(1, tensor_info.GetNumElements())).astype(np.float16))
        # Switch the binding's TensorInfo to Float16 so make_input_tensors
        # builds a Float16 tensor; relies on SetDataType updating the object
        # held in input_binding_info.
        tensor_info.SetDataType(ann.DataType_Float16)

    input_tensors = ann.make_input_tensors(input_binding_info, input_data)
    assert len(input_tensors) == 1

    for input_tensor, binding_info in zip(input_tensors, input_binding_info):
        const_tensor = input_tensor[1]
        # pyarmnn creates ConstTensor through a function, so the type cannot be
        # checked directly; compare the class name instead.
        assert type(const_tensor).__name__ == 'ConstTensor'
        assert str(const_tensor.GetInfo()) == str(binding_info[1])
        assert const_tensor.GetDataType() == ann.DataType_Float16
        assert const_tensor.GetNumElements() == expected_num_elements
        # Each float16 element occupies np.dtype(np.float16).itemsize (2) bytes.
        assert const_tensor.GetNumBytes() == expected_num_elements * np.dtype(np.float16).itemsize
100