xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_modeloption.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
3*89c4ff92SAndroid Build Coastguard Workerimport pytest
4*89c4ff92SAndroid Build Coastguard Worker
5*89c4ff92SAndroid Build Coastguard Workerfrom pyarmnn import BackendOptions, BackendOption, BackendId, OptimizerOptions, ShapeInferenceMethod_InferAndValidate
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker
@pytest.mark.parametrize("data", (True, -100, 128, 0.12345, 'string'))
def test_backend_option_ctor(data):
    """A BackendOption constructed from each parametrized value reports the given name."""
    option = BackendOption("name", data)
    assert option.GetName() == "name"
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker
def test_backend_options_ctor():
    """BackendOptions can be built from a BackendId and copy-constructed from another BackendOptions."""
    backend = BackendId('a')
    original = BackendOptions(backend)
    assert str(original.GetBackendId()) == 'a'

    # Construct a second instance from the first; it must report the same backend id.
    copied = BackendOptions(original)
    assert str(copied.GetBackendId()) == 'a'
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker
def test_backend_options_add():
    """Options added via AddOption are counted and reachable by index, GetOption and iteration."""
    backend = BackendId('a')
    options = BackendOptions(backend)
    first = BackendOption("name", 1)
    options.AddOption(first)

    # GetOptionCount() and len() must agree.
    assert options.GetOptionCount() == 1
    assert len(options) == 1

    # The single option is visible through every access path.
    assert options[0].GetName() == 'name'
    assert options.GetOption(0).GetName() == 'name'
    assert all(opt.GetName() == 'name' for opt in options)

    options.AddOption(BackendOption("name2", 2))

    assert options.GetOptionCount() == 2
    assert len(options) == 2
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker
def test_backend_option_ownership():
    """Check SWIG ownership flags for options added to, and read back from, BackendOptions."""
    backend = BackendId('b')
    options = BackendOptions(backend)
    flag = BackendOption('option', True)
    options.AddOption(flag)

    # After AddOption the Python-side option still reports ownership of its object.
    assert flag.thisown

    # Deleting the Python-side option must not remove the entry from the container.
    del flag

    assert options.GetOptionCount() == 1
    stored = options[0]
    # Proxies obtained by indexing do not own the underlying object.
    assert not stored.thisown
    assert stored.GetName() == 'option'

    del stored

    # Dropping a non-owning proxy must leave the container entry readable.
    refetched = options[0]
    assert not refetched.thisown
    assert refetched.GetName() == 'option'
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker
def test_optimizer_options_with_model_opt():
    """m_ModelOptions reflects model options passed to the constructor and later reassignment."""
    opts_a = BackendOptions(BackendId('a'))

    optimizer_opts = OptimizerOptions(True, False, False,
                                      ShapeInferenceMethod_InferAndValidate,
                                      True, [opts_a], True)

    model_opts = optimizer_opts.m_ModelOptions
    assert len(model_opts) == 1
    assert str(model_opts[0].GetBackendId()) == 'a'

    opts_b = BackendOptions(BackendId('b'))
    opts_c = BackendOptions(BackendId('c'))

    # Reassign the model options from a tuple and verify order is preserved.
    optimizer_opts.m_ModelOptions = (opts_a, opts_b, opts_c)

    model_opts = optimizer_opts.m_ModelOptions
    assert len(optimizer_opts.m_ModelOptions) == 3

    for index, expected_id in enumerate(('a', 'b', 'c')):
        assert str(model_opts[index].GetBackendId()) == expected_id
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker
def test_optimizer_option_default():
    """OptimizerOptions constructed without model options has an empty m_ModelOptions."""
    optimizer_opts = OptimizerOptions(True, False, False,
                                      ShapeInferenceMethod_InferAndValidate, True)
    assert len(optimizer_opts.m_ModelOptions) == 0
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker
def test_optimizer_options_fail():
    """OptimizerOptions rejects malformed model options with a TypeError.

    Three cases are checked:
      * passing a single BackendOptions (not a sequence) to the constructor;
      * assigning a string to ``m_ModelOptions``;
      * assigning a list that mixes a string with a valid BackendOptions.
    """
    a = BackendOptions(BackendId('a'))

    # Constructor case: the call itself is the operation under test.
    with pytest.raises(TypeError) as err:
        OptimizerOptions(True,
                         False,
                         False,
                         ShapeInferenceMethod_InferAndValidate,
                         True,
                         a,
                         True)

    assert "Wrong number or type of arguments" in str(err.value)

    # Setter cases: build the object OUTSIDE pytest.raises so that only the
    # assignment being tested can satisfy the expected TypeError.
    oo = OptimizerOptions(True,
                          False,
                          False,
                          ShapeInferenceMethod_InferAndValidate,
                          True)

    with pytest.raises(TypeError) as err:
        oo.m_ModelOptions = 'nonsense'

    assert "in method 'OptimizerOptions_m_ModelOptions_set', argument 2" in str(err.value)

    # Use a fresh instance so this case does not depend on the previous one.
    oo = OptimizerOptions(True,
                          False,
                          False,
                          ShapeInferenceMethod_InferAndValidate,
                          True)

    with pytest.raises(TypeError) as err:
        oo.m_ModelOptions = ['nonsense', a]

    assert "in method 'OptimizerOptions_m_ModelOptions_set', argument 2" in str(err.value)
142