xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_iconnectable.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker# Copyright © 2020 Arm Ltd. All rights reserved.
2*89c4ff92SAndroid Build Coastguard Worker# SPDX-License-Identifier: MIT
3*89c4ff92SAndroid Build Coastguard Workerimport pytest
4*89c4ff92SAndroid Build Coastguard Worker
5*89c4ff92SAndroid Build Coastguard Workerimport pyarmnn as ann
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker
@pytest.fixture(scope="function")
def network():
    """Yield-free fixture providing a brand-new, empty ``ann.INetwork``.

    Function scope means every test receives its own network instance, so
    layers added by one test never leak into another.
    """
    fresh_network = ann.INetwork()
    return fresh_network
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker
class TestIInputIOutputIConnectable:
    """Tests for the pyarmnn IInputSlot, IOutputSlot and IConnectableLayer bindings."""

    @staticmethod
    def _build_addition_network(network, connect=True):
        """Add two inputs, an addition layer and an output layer to *network*.

        Layers are created as ``input1`` (binding 0), ``input2`` (binding 1),
        ``addition`` and ``output`` (binding 0). When *connect* is true the
        graph is wired as ``input1 -> add[0]``, ``input2 -> add[1]`` and
        ``add -> output[0]``.

        Returns:
            Tuple ``(input1, input2, add, output)`` of the created layers.
        """
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        if connect:
            input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
            input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
            add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        return input1, input2, add, output

    def test_input_slot(self, network):
        _, _, add, _ = self._build_addition_network(network)

        # Check IInputSlot GetConnection()
        input_slot = add.GetInputSlot(0)
        input_slot_connection = input_slot.GetConnection()

        assert isinstance(input_slot_connection, ann.IOutputSlot)

        # Deleting the previously returned wrapper must not stop the slot
        # from being queried again.
        del input_slot_connection

        assert input_slot.GetConnection()
        assert isinstance(input_slot.GetConnection(), ann.IOutputSlot)

        del input_slot

        assert add.GetInputSlot(0)

    def test_output_slot(self, network):
        _, _, add, output = self._build_addition_network(network)

        # IInputSlot.GetConnection() returns the IOutputSlot feeding that input:
        # add's input 0 is fed by input1's output slot, and output's input 0
        # is fed by add's output slot.
        input1_output_slot = add.GetInputSlot(0).GetConnection()
        add_output_slot = output.GetInputSlot(0).GetConnection()

        # Check IOutputSlot GetConnection(): connection 0 of add's output slot is
        # output's input slot, whose own connection is an IOutputSlot again.
        add_get_output_connect = add.GetOutputSlot(0).GetConnection(0)
        assert isinstance(add_get_output_connect.GetConnection(), ann.IOutputSlot)

        # Test IOutputSlot GetNumConnections(), len(), indexing & CalculateIndexOnOwner()
        assert input1_output_slot.GetNumConnections() == 1
        assert len(input1_output_slot) == 1
        assert input1_output_slot[0]
        assert input1_output_slot.CalculateIndexOnOwner() == 0

        # Check GetOwningLayerGuid(): input1's and add's output slots belong to different layers
        assert input1_output_slot.GetOwningLayerGuid() != add_output_slot.GetOwningLayerGuid()

        # Create a 2x3 float32 TensorInfo to set on input1's output slot
        test_tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32)

        # Check IsTensorInfoSet() before and after SetTensorInfo()
        assert not input1_output_slot.IsTensorInfoSet()
        input1_output_slot.SetTensorInfo(test_tensor_info)
        assert input1_output_slot.IsTensorInfoSet()

        # Check GetTensorInfo() reflects the 2x3 shape that was set
        output_tensor_info = input1_output_slot.GetTensorInfo()
        assert output_tensor_info.GetNumDimensions() == 2
        assert output_tensor_info.GetNumElements() == 6

        # Check Disconnect(): add's output slot has exactly one connection
        # (to the output layer's input slot 0) until it is disconnected.
        assert add_output_slot.GetNumConnections() == 1
        add.GetOutputSlot(0).Disconnect(output.GetInputSlot(0))
        assert add_output_slot.GetNumConnections() == 0

    def test_output_slot__out_of_range(self, network):
        # Create input layer to check output slot get item handling
        input1 = network.AddInputLayer(0, "input1")

        output_slot = input1.GetOutputSlot(0)
        # The slot has no connections, so index 1 is out of range.
        with pytest.raises(ValueError, match="Invalid index 1 provided"):
            output_slot[1]

    def test_iconnectable_guid(self, network):
        # Check IConnectable GetGuid()
        # Note Guid can change based on which tests are run so
        # checking here that each layer does not have the same guid
        add_id = network.AddAdditionLayer().GetGuid()
        output_id = network.AddOutputLayer(0).GetGuid()
        assert add_id != output_id

    def test_iconnectable_layer_functions(self, network):
        # Layers are intentionally left unconnected in this test.
        input1, input2, add, output = self._build_addition_network(network, connect=False)

        # Check GetNumInputSlots(), GetName() & GetNumOutputSlots()
        assert input1.GetNumInputSlots() == 0
        assert input1.GetName() == "input1"
        assert input1.GetNumOutputSlots() == 1

        assert input2.GetNumInputSlots() == 0
        assert input2.GetName() == "input2"
        assert input2.GetNumOutputSlots() == 1

        assert add.GetNumInputSlots() == 2
        assert add.GetName() == "addition"
        assert add.GetNumOutputSlots() == 1

        assert output.GetNumInputSlots() == 1
        assert output.GetName() == "output"
        assert output.GetNumOutputSlots() == 0

        # Check GetOutputSlot()
        input1_get_output = input1.GetOutputSlot(0)
        assert input1_get_output.GetNumConnections() == 0
        assert len(input1_get_output) == 0

        # Check GetInputSlot(). add's input slot is unconnected here; the
        # GetConnection() call presumably only checks that querying an
        # unconnected slot does not raise (its result is not asserted).
        add_get_input = add.GetInputSlot(0)
        add_get_input.GetConnection()
        assert isinstance(add_get_input, ann.IInputSlot)
143