/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23 */ 24 #include "arm_compute/graph/Graph.h" 25 26 namespace arm_compute 27 { 28 namespace graph 29 { Graph(GraphID id,std::string name)30 Graph::Graph(GraphID id, std::string name) 31 : _id(id), _name(std::move(name)), _nodes(), _edges(), _tensors(), _tagged_nodes(), _mtx() 32 { 33 } 34 remove_node(NodeID nid)35 bool Graph::remove_node(NodeID nid) 36 { 37 if(nid >= _nodes.size()) 38 { 39 return false; 40 } 41 42 std::unique_ptr<INode> &node = _nodes[nid]; 43 44 if(node) 45 { 46 // Remove input connections 47 for(auto &input_eid : node->_input_edges) 48 { 49 remove_connection(input_eid); 50 } 51 52 // Remove output connections 53 std::set<EdgeID> output_edges_copy = node->output_edges(); 54 for(auto &output_eid : output_edges_copy) 55 { 56 remove_connection(output_eid); 57 } 58 59 // Remove nid from tagged nodes 60 std::vector<NodeID> &tnodes = _tagged_nodes.at(node->type()); 61 tnodes.erase(std::remove(tnodes.begin(), tnodes.end(), nid), tnodes.end()); 62 } 63 64 node = nullptr; 65 66 return true; 67 } 68 add_connection(NodeID source,size_t source_idx,NodeID sink,size_t sink_idx)69 EdgeID Graph::add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx) 70 { 71 arm_compute::lock_guard<arm_compute::Mutex> lock(_mtx); 72 73 // Check if node index is valid, if node exists and finally if the connection index is valid 74 ARM_COMPUTE_ERROR_ON((source >= _nodes.size()) || (_nodes[source] == nullptr) || (source_idx >= _nodes[source]->num_outputs())); 75 ARM_COMPUTE_ERROR_ON((sink >= _nodes.size()) || (_nodes[sink] == nullptr) || (sink_idx >= _nodes[sink]->num_inputs())); 76 77 // Get nodes 78 std::unique_ptr<INode> &source_node = _nodes[source]; 79 std::unique_ptr<INode> &sink_node = _nodes[sink]; 80 81 // Check for duplicate connections (Check only sink node) 82 Edge *sink_node_edge = sink_node->input_edge(sink_idx); 83 if((sink_node_edge != nullptr) && (sink_node_edge->producer_id() == source) && (sink_node_edge->producer_idx() == source_idx) 84 && 
(sink_node_edge->consumer_id() == sink) && (sink_node_edge->consumer_idx() == sink_idx)) 85 { 86 return sink_node_edge->id(); 87 } 88 89 // Check if there is already a tensor associated with output if not create one 90 TensorID tid = source_node->output_id(source_idx); 91 if(tid == NullTensorID) 92 { 93 tid = create_tensor(); 94 } 95 std::unique_ptr<Tensor> &tensor = _tensors[tid]; 96 97 // Create connections 98 EdgeID eid = _edges.size(); 99 auto connection = std::make_unique<Edge>(eid, source_node.get(), source_idx, sink_node.get(), sink_idx, tensor.get()); 100 _edges.push_back(std::move(connection)); 101 102 // Add connections to source and sink nodes 103 source_node->_output_edges.insert(eid); 104 sink_node->_input_edges[sink_idx] = eid; 105 106 // Set tensor output node 107 source_node->_outputs[source_idx] = tid; 108 109 // Bind tensor to the edge 110 tensor->bind_edge(eid); 111 112 // Try and propagate shapes in sink node 113 sink_node->forward_descriptors(); 114 115 return eid; 116 } 117 remove_connection(EdgeID eid)118 bool Graph::remove_connection(EdgeID eid) 119 { 120 if(eid >= _edges.size()) 121 { 122 return false; 123 } 124 125 std::unique_ptr<Edge> &edge = _edges[eid]; 126 127 // Remove node connections 128 if(edge != nullptr) 129 { 130 // Get tensor bound to the edge 131 if(edge->tensor() != nullptr) 132 { 133 edge->tensor()->unbind_edge(eid); 134 } 135 136 // Remove edges from source node 137 if(edge->producer() != nullptr) 138 { 139 edge->producer()->_output_edges.erase(eid); 140 } 141 142 // Remove edges from sink node 143 if((edge->consumer() != nullptr) && (edge->consumer_idx() < edge->consumer()->_input_edges.size())) 144 { 145 edge->consumer()->_input_edges[edge->consumer_idx()] = EmptyEdgeID; 146 } 147 } 148 149 // Clear edge 150 edge = nullptr; 151 152 return true; 153 } 154 create_tensor(const TensorDescriptor & desc)155 TensorID Graph::create_tensor(const TensorDescriptor &desc) 156 { 157 TensorID tid = _tensors.size(); 158 auto tensor = 
std::make_unique<Tensor>(tid, desc); 159 _tensors.push_back(std::move(tensor)); 160 161 return tid; 162 } 163 name() const164 std::string Graph::name() const 165 { 166 return _name; 167 } 168 id() const169 GraphID Graph::id() const 170 { 171 return _id; 172 } 173 nodes(NodeType type)174 const std::vector<NodeID> &Graph::nodes(NodeType type) 175 { 176 return _tagged_nodes[type]; 177 } 178 nodes()179 std::vector<std::unique_ptr<INode>> &Graph::nodes() 180 { 181 return _nodes; 182 } 183 nodes() const184 const std::vector<std::unique_ptr<INode>> &Graph::nodes() const 185 { 186 return _nodes; 187 } 188 edges() const189 const std::vector<std::unique_ptr<Edge>> &Graph::edges() const 190 { 191 return _edges; 192 } 193 tensors()194 std::vector<std::unique_ptr<Tensor>> &Graph::tensors() 195 { 196 return _tensors; 197 } 198 tensors() const199 const std::vector<std::unique_ptr<Tensor>> &Graph::tensors() const 200 { 201 return _tensors; 202 } 203 node(NodeID id) const204 const INode *Graph::node(NodeID id) const 205 { 206 return (id >= _nodes.size()) ? nullptr : _nodes[id].get(); 207 } 208 node(NodeID id)209 INode *Graph::node(NodeID id) 210 { 211 return (id >= _nodes.size()) ? nullptr : _nodes[id].get(); 212 } 213 edge(EdgeID id) const214 const Edge *Graph::edge(EdgeID id) const 215 { 216 return (id >= _edges.size()) ? nullptr : _edges[id].get(); 217 } 218 edge(EdgeID id)219 Edge *Graph::edge(EdgeID id) 220 { 221 return (id >= _edges.size()) ? nullptr : _edges[id].get(); 222 } 223 tensor(TensorID id) const224 const Tensor *Graph::tensor(TensorID id) const 225 { 226 return (id >= _tensors.size()) ? nullptr : _tensors[id].get(); 227 } 228 tensor(TensorID id)229 Tensor *Graph::tensor(TensorID id) 230 { 231 return (id >= _tensors.size()) ? nullptr : _tensors[id].get(); 232 } 233 } // namespace graph 234 } // namespace arm_compute