# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Tests for the pyarmnn BackendOption / BackendOptions bindings and for the
``m_ModelOptions`` field of OptimizerOptions."""
import pytest

from pyarmnn import BackendOptions, BackendOption, BackendId, OptimizerOptions, ShapeInferenceMethod_InferAndValidate


@pytest.mark.parametrize("data", (True, -100, 128, 0.12345, 'string'))
def test_backend_option_ctor(data):
    """BackendOption can be built from bool, int, float and str values and keeps its name."""
    bo = BackendOption("name", data)
    assert "name" == bo.GetName()


def test_backend_options_ctor():
    """BackendOptions can be built from a BackendId or from another BackendOptions."""
    backend_id = BackendId('a')
    bos = BackendOptions(backend_id)

    assert 'a' == str(bos.GetBackendId())

    # Constructing from an existing BackendOptions (presumably the copy-constructor
    # overload) must carry the backend id over.
    another_bos = BackendOptions(bos)
    assert 'a' == str(another_bos.GetBackendId())


def test_backend_options_add():
    """Added options are counted and reachable by index, GetOption() and iteration."""
    backend_id = BackendId('a')
    bos = BackendOptions(backend_id)
    bo = BackendOption("name", 1)
    bos.AddOption(bo)

    assert 1 == bos.GetOptionCount()
    assert 1 == len(bos)

    assert 'name' == bos[0].GetName()
    assert 'name' == bos.GetOption(0).GetName()
    # Compare the collected names instead of asserting inside a loop, so an
    # iterator that yields nothing cannot make this check pass vacuously.
    assert ['name'] == [option.GetName() for option in bos]

    bos.AddOption(BackendOption("name2", 2))

    assert 2 == bos.GetOptionCount()
    assert 2 == len(bos)


def test_backend_option_ownership():
    """Checks the SWIG ``thisown`` flags around BackendOptions.AddOption and indexing.

    The wrapper passed to AddOption keeps ownership of its object. Options read
    back via ``bos[i]`` do not own theirs, and remain usable after the original
    wrapper and earlier index results have been deleted.
    """
    backend_id = BackendId('b')
    bos = BackendOptions(backend_id)
    bo = BackendOption('option', True)
    bos.AddOption(bo)

    # NOTE(review): the caller's wrapper still owns its object, presumably
    # because AddOption stores a copy -- confirm against the binding.
    assert bo.thisown

    del bo

    # The stored option must survive deletion of the wrapper that was added.
    assert 1 == bos.GetOptionCount()
    option = bos[0]
    assert not option.thisown
    assert 'option' == option.GetName()

    # Dropping a non-owning result must leave the option held by bos intact.
    del option

    option_again = bos[0]
    assert not option_again.thisown
    assert 'option' == option_again.GetName()


def test_optimizer_options_with_model_opt():
    """m_ModelOptions accepts a list at construction and a tuple on assignment, keeping order."""
    a = BackendOptions(BackendId('a'))

    oo = OptimizerOptions(True,
                          False,
                          False,
                          ShapeInferenceMethod_InferAndValidate,
                          True,
                          [a],
                          True)

    mo = oo.m_ModelOptions

    assert 1 == len(mo)
    assert 'a' == str(mo[0].GetBackendId())

    b = BackendOptions(BackendId('b'))

    c = BackendOptions(BackendId('c'))

    oo.m_ModelOptions = (a, b, c)

    # Re-read the field so the assertions below see the newly assigned value.
    mo = oo.m_ModelOptions

    assert 3 == len(mo)

    assert 'a' == str(mo[0].GetBackendId())
    assert 'b' == str(mo[1].GetBackendId())
    assert 'c' == str(mo[2].GetBackendId())


def test_optimizer_option_default():
    """OptimizerOptions built without model options has an empty m_ModelOptions."""
    oo = OptimizerOptions(True,
                          False,
                          False,
                          ShapeInferenceMethod_InferAndValidate,
                          True)

    assert 0 == len(oo.m_ModelOptions)


def test_optimizer_options_fail():
    """Invalid model options raise TypeError, both in the constructor and on assignment."""
    a = BackendOptions(BackendId('a'))

    # A bare BackendOptions (not a sequence of them) matches no constructor overload.
    with pytest.raises(TypeError) as err:
        OptimizerOptions(True,
                         False,
                         False,
                         ShapeInferenceMethod_InferAndValidate,
                         True,
                         a,
                         True)

    assert "Wrong number or type of arguments" in str(err.value)

    # Build the object outside pytest.raises so that only the assignment under
    # test can satisfy the expected TypeError.
    oo = OptimizerOptions(True,
                          False,
                          False,
                          ShapeInferenceMethod_InferAndValidate,
                          True)

    # A string is not a sequence of BackendOptions.
    with pytest.raises(TypeError) as err:
        oo.m_ModelOptions = 'nonsense'

    assert "in method 'OptimizerOptions_m_ModelOptions_set', argument 2" in str(err.value)

    oo = OptimizerOptions(True,
                          False,
                          False,
                          ShapeInferenceMethod_InferAndValidate,
                          True)

    # A single non-BackendOptions element is enough to reject the whole sequence.
    with pytest.raises(TypeError) as err:
        oo.m_ModelOptions = ['nonsense', a]

    assert "in method 'OptimizerOptions_m_ModelOptions_set', argument 2" in str(err.value)