xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/common/network_executor.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4import os
5from typing import List, Tuple
6
7import pyarmnn as ann
8import numpy as np
9
class ArmnnNetworkExecutor:

    def __init__(self, model_file: str, backends: list):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file type is not supported (only .tflite is).
        """
        self.model_file = model_file
        self.backends = backends
        self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network()
        # Output tensors are created once here and reused by every run() call.
        self.output_tensors = ann.make_output_tensors(self.output_binding_info)

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Creates input tensors from input data and executes inference with the loaded network.

        Args:
            input_data_list: List of input frames, one per network input.

        Returns:
            list: Inference results as a list of ndarrays.
        """
        input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list)
        self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors)
        return ann.workload_tensors_to_ndarray(self.output_tensors)

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            net_id: Unique ID of the network to run.
            runtime: Runtime context for executing inference.
            input_binding_info: List of (layer id, tensor info) pairs, one per model input.
            output_binding_info: List of (layer id, tensor info) pairs, one per model output.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file extension is not '.tflite'.
        """
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        # Only the TfLite parser is wired up in this example.
        _, ext = os.path.splitext(self.model_file)
        if ext == '.tflite':
            parser = ann.ITfLiteParser()
        else:
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        network = parser.CreateNetworkFromBinaryFile(self.model_file)

        # Specify backends to optimize network
        preferred_backends = [ann.BackendId(backend) for backend in self.backends]

        # Select appropriate device context and optimize the network for that device
        options = ann.CreationOptions()
        runtime = ann.IRuntime(options)
        opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                             ann.OptimizerOptions())
        print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n'
              f'Optimization warnings: {messages}')

        # Load the optimized network onto the Runtime device
        net_id, _ = runtime.LoadNetwork(opt_network)

        # Get input and output binding information for the last subgraph
        graph_id = parser.GetSubgraphCount() - 1
        input_binding_info = [parser.GetNetworkInputBindingInfo(graph_id, name)
                              for name in parser.GetSubgraphInputTensorNames(graph_id)]
        output_binding_info = [parser.GetNetworkOutputBindingInfo(graph_id, name)
                               for name in parser.GetSubgraphOutputTensorNames(graph_id)]
        return net_id, runtime, input_binding_info, output_binding_info

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            numpy data type, or None if the ArmNN data type is not one of
            Float32 / QAsymmU8 / QAsymmS8.
        """
        # Query the ArmNN data type once instead of once per comparison.
        data_type = self.input_binding_info[0][1].GetDataType()
        return {
            ann.DataType_Float32: np.float32,
            ann.DataType_QAsymmU8: np.uint8,
            ann.DataType_QAsymmS8: np.int8,
        }.get(data_type)

    def get_shape(self) -> Tuple:
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The shape of the (first) network input.
        """
        return tuple(self.input_binding_info[0][1].GetShape())

    def get_input_quantization_scale(self, idx):
        """
        Get the input quantization scale of the initiated network.

        Args:
            idx: Index of the network input.

        Returns:
            The quantization scale of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationScale()

    def get_input_quantization_offset(self, idx):
        """
        Get the input quantization offset of the initiated network.

        Args:
            idx: Index of the network input.

        Returns:
            The quantization offset of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationOffset()

    def is_output_quantized(self, idx):
        """
        Tell whether the given output tensor is quantized.

        Args:
            idx: Index of the network output.

        Returns:
            True if the output is quantized and False otherwise.
        """
        return self.output_binding_info[idx][1].IsQuantized()

    def get_output_quantization_scale(self, idx):
        """
        Get the output quantization scale of the initiated network.

        Args:
            idx: Index of the network output.

        Returns:
            The quantization scale of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationScale()

    def get_output_quantization_offset(self, idx):
        """
        Get the output quantization offset of the initiated network.

        Args:
            idx: Index of the network output.

        Returns:
            The quantization offset of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationOffset()
161
162