# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import pytest

import pyarmnn as ann


@pytest.fixture(scope="function")
def network():
    """Provide a new ``ann.INetwork`` instance for each test function."""
    return ann.INetwork()


class TestIInputIOutputIConnectable:
    """Tests for the IInputSlot, IOutputSlot and IConnectableLayer bindings."""

    def test_input_slot(self, network):
        """IInputSlot.GetConnection() returns an IOutputSlot, and slots can
        still be queried after earlier Python references have been deleted."""
        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Connect the input/output slots for each layer
        input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
        input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
        add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        # Check IInputSlot GetConnection()
        input_slot = add.GetInputSlot(0)
        input_slot_connection = input_slot.GetConnection()

        assert isinstance(input_slot_connection, ann.IOutputSlot)

        # Drop the Python reference to the connection, then check that a
        # fresh GetConnection() call still yields a valid IOutputSlot.
        del input_slot_connection

        assert input_slot.GetConnection()
        assert isinstance(input_slot.GetConnection(), ann.IOutputSlot)

        # Likewise, the layer must still hand out its input slot after the
        # previous reference to it is gone.
        del input_slot

        assert add.GetInputSlot(0)

    def test_output_slot(self, network):
        """Exercise IOutputSlot queries, TensorInfo handling and Disconnect()."""

        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Connect the input/output slots for each layer
        input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
        input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
        add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        # Check IInputSlot GetConnection()
        # add.InputSlot0 is fed by input1.OutputSlot0
        add_get_input_connection = add.GetInputSlot(0).GetConnection()
        # output.InputSlot0 is fed by add.OutputSlot0
        output_get_input_connection = output.GetInputSlot(0).GetConnection()

        # Check IOutputSlot GetConnection(): add.OutputSlot0's first connection
        # is output.InputSlot0, whose own connection is an IOutputSlot again.
        add_get_output_connect = add.GetOutputSlot(0).GetConnection(0)
        assert isinstance(add_get_output_connect.GetConnection(), ann.IOutputSlot)

        # Test IOutputSlot GetNumConnections() & CalculateIndexOnOwner()
        assert add_get_input_connection.GetNumConnections() == 1
        assert len(add_get_input_connection) == 1
        assert add_get_input_connection[0]
        assert add_get_input_connection.CalculateIndexOnOwner() == 0

        # Check GetOwningLayerGuid(): the two output slots are owned by
        # different layers (input1 and the addition layer), so the GUIDs differ
        assert add_get_input_connection.GetOwningLayerGuid() != output_get_input_connection.GetOwningLayerGuid()

        # Create a 2x3 Float32 TensorInfo to attach to the slot below
        test_tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32)

        # Check IsTensorInfoSet() before and after SetTensorInfo()
        assert not add_get_input_connection.IsTensorInfoSet()
        add_get_input_connection.SetTensorInfo(test_tensor_info)
        assert add_get_input_connection.IsTensorInfoSet()

        # Check GetTensorInfo() returns the shape that was set (2 dims, 2*3 elements)
        output_tensor_info = add_get_input_connection.GetTensorInfo()
        assert 2 == output_tensor_info.GetNumDimensions()
        assert 6 == output_tensor_info.GetNumElements()

        # Check Disconnect()
        assert output_get_input_connection.GetNumConnections() == 1  # add.OutputSlot0 -> output.InputSlot0
        add.GetOutputSlot(0).Disconnect(output.GetInputSlot(0))  # disconnect add.OutputSlot0 from output.InputSlot0
        assert output_get_input_connection.GetNumConnections() == 0

    def test_output_slot__out_of_range(self, network):
        """Indexing an IOutputSlot beyond its connections raises ValueError."""
        # Create input layer to check output slot get item handling
        input1 = network.AddInputLayer(0, "input1")

        # Nothing is connected to this slot, so index 1 is out of range.
        output_slot = input1.GetOutputSlot(0)
        with pytest.raises(ValueError) as err:
            output_slot[1]

        assert "Invalid index 1 provided" in str(err.value)

    def test_iconnectable_guid(self, network):
        """Distinct layers in the same network report distinct GUIDs."""

        # Check IConnectable GetGuid()
        # Note Guid can change based on which tests are run so
        # checking here that each layer does not have the same guid
        add_id = network.AddAdditionLayer().GetGuid()
        output_id = network.AddOutputLayer(0).GetGuid()
        assert add_id != output_id

    def test_iconnectable_layer_functions(self, network):
        """Check slot counts, names and slot accessors on unconnected layers."""

        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Check GetNumInputSlots(), GetName() & GetNumOutputSlots()
        assert input1.GetNumInputSlots() == 0
        assert input1.GetName() == "input1"
        assert input1.GetNumOutputSlots() == 1

        assert input2.GetNumInputSlots() == 0
        assert input2.GetName() == "input2"
        assert input2.GetNumOutputSlots() == 1

        assert add.GetNumInputSlots() == 2
        assert add.GetName() == "addition"
        assert add.GetNumOutputSlots() == 1

        assert output.GetNumInputSlots() == 1
        assert output.GetName() == "output"
        assert output.GetNumOutputSlots() == 0

        # Check GetOutputSlot(): no connections were made in this test
        input1_get_output = input1.GetOutputSlot(0)
        assert input1_get_output.GetNumConnections() == 0
        assert len(input1_get_output) == 0

        # Check GetInputSlot(). The GetConnection() return value is not
        # checked; this only exercises the call on an unconnected slot.
        add_get_input = add.GetInputSlot(0)
        add_get_input.GetConnection()
        assert isinstance(add_get_input, ann.IInputSlot)