xref: /aosp_15_r20/external/ComputeLibrary/python/scripts/utils/model_identification.py (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust# Copyright (c) 2021 Arm Limited.
2*c217d954SCole Faust#
3*c217d954SCole Faust# SPDX-License-Identifier: MIT
4*c217d954SCole Faust#
5*c217d954SCole Faust# Permission is hereby granted, free of charge, to any person obtaining a copy
6*c217d954SCole Faust# of this software and associated documentation files (the "Software"), to
7*c217d954SCole Faust# deal in the Software without restriction, including without limitation the
8*c217d954SCole Faust# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
9*c217d954SCole Faust# sell copies of the Software, and to permit persons to whom the Software is
10*c217d954SCole Faust# furnished to do so, subject to the following conditions:
11*c217d954SCole Faust#
12*c217d954SCole Faust# The above copyright notice and this permission notice shall be included in all
13*c217d954SCole Faust# copies or substantial portions of the Software.
14*c217d954SCole Faust#
15*c217d954SCole Faust# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16*c217d954SCole Faust# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17*c217d954SCole Faust# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18*c217d954SCole Faust# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19*c217d954SCole Faust# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20*c217d954SCole Faust# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21*c217d954SCole Faust# SOFTWARE.
22*c217d954SCole Faustimport logging
23*c217d954SCole Faustimport os
24*c217d954SCole Faust
25*c217d954SCole Faust
def is_tflite_model(model_path):
    """Check if a model is of TFLite type

    Reads the first 8 bytes of the file and checks whether bytes 4..7 equal
    the ``b"TFL3"`` marker.

    Parameters
    ----------
    model_path: str
        Path to model

    Returns
    ----------
    bool:
        True if the file header carries the TFLite marker. False otherwise,
        including when the path is missing, unreadable, too short, or not a
        valid path argument.
    """
    try:
        with open(model_path, "rb") as f:
            hdr_bytes = f.read(8)
    except (OSError, TypeError, ValueError):
        # Best-effort probe: a missing/unreadable file or an invalid path
        # argument simply means "not a TFLite model". Narrow catch so that
        # KeyboardInterrupt/SystemExit and genuine bugs still propagate.
        return False

    # Compare raw bytes directly; avoids a UTF-8 decode that could itself fail
    # on arbitrary binary headers. Short files yield a shorter slice -> False.
    return hdr_bytes[4:8] == b"TFL3"
50*c217d954SCole Faust
51*c217d954SCole Faust
def identify_model_type(model_path):
    """Identify the type of a given deep learning model

    Parameters
    ----------
    model_path: str
        Path to model

    Returns
    ----------
    model_type: str or None
        String representation of the model type (currently only "tflite"),
        or None if the path does not exist or the type is not supported.
        A warning is logged in both None cases.
    """

    if not os.path.exists(model_path):
        logging.warning("Provided model %s does not exist!", model_path)
        return None

    if is_tflite_model(model_path):
        return "tflite"

    # Emit exactly one warning record for an unsupported model.
    logging.warning("Provided model %s is not of supported type!", model_path)
    return None
76*c217d954SCole Faust    return model_type
77