xref: /aosp_15_r20/external/ComputeLibrary/src/graph/Graph.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1  /*
2   * Copyright (c) 2018-2020 Arm Limited.
3   *
4   * SPDX-License-Identifier: MIT
5   *
6   * Permission is hereby granted, free of charge, to any person obtaining a copy
7   * of this software and associated documentation files (the "Software"), to
8   * deal in the Software without restriction, including without limitation the
9   * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10   * sell copies of the Software, and to permit persons to whom the Software is
11   * furnished to do so, subject to the following conditions:
12   *
13   * The above copyright notice and this permission notice shall be included in all
14   * copies or substantial portions of the Software.
15   *
16   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22   * SOFTWARE.
23   */
24  #include "arm_compute/graph/Graph.h"
25  
26  namespace arm_compute
27  {
28  namespace graph
29  {
Graph(GraphID id,std::string name)30  Graph::Graph(GraphID id, std::string name)
31      : _id(id), _name(std::move(name)), _nodes(), _edges(), _tensors(), _tagged_nodes(), _mtx()
32  {
33  }
34  
remove_node(NodeID nid)35  bool Graph::remove_node(NodeID nid)
36  {
37      if(nid >= _nodes.size())
38      {
39          return false;
40      }
41  
42      std::unique_ptr<INode> &node = _nodes[nid];
43  
44      if(node)
45      {
46          // Remove input connections
47          for(auto &input_eid : node->_input_edges)
48          {
49              remove_connection(input_eid);
50          }
51  
52          // Remove output connections
53          std::set<EdgeID> output_edges_copy = node->output_edges();
54          for(auto &output_eid : output_edges_copy)
55          {
56              remove_connection(output_eid);
57          }
58  
59          // Remove nid from tagged nodes
60          std::vector<NodeID> &tnodes = _tagged_nodes.at(node->type());
61          tnodes.erase(std::remove(tnodes.begin(), tnodes.end(), nid), tnodes.end());
62      }
63  
64      node = nullptr;
65  
66      return true;
67  }
68  
add_connection(NodeID source,size_t source_idx,NodeID sink,size_t sink_idx)69  EdgeID Graph::add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx)
70  {
71      arm_compute::lock_guard<arm_compute::Mutex> lock(_mtx);
72  
73      // Check if node index is valid, if node exists and finally if the connection index is valid
74      ARM_COMPUTE_ERROR_ON((source >= _nodes.size()) || (_nodes[source] == nullptr) || (source_idx >= _nodes[source]->num_outputs()));
75      ARM_COMPUTE_ERROR_ON((sink >= _nodes.size()) || (_nodes[sink] == nullptr) || (sink_idx >= _nodes[sink]->num_inputs()));
76  
77      // Get nodes
78      std::unique_ptr<INode> &source_node = _nodes[source];
79      std::unique_ptr<INode> &sink_node   = _nodes[sink];
80  
81      // Check for duplicate connections (Check only sink node)
82      Edge *sink_node_edge = sink_node->input_edge(sink_idx);
83      if((sink_node_edge != nullptr) && (sink_node_edge->producer_id() == source) && (sink_node_edge->producer_idx() == source_idx)
84         && (sink_node_edge->consumer_id() == sink) && (sink_node_edge->consumer_idx() == sink_idx))
85      {
86          return sink_node_edge->id();
87      }
88  
89      // Check if there is already a tensor associated with output if not create one
90      TensorID tid = source_node->output_id(source_idx);
91      if(tid == NullTensorID)
92      {
93          tid = create_tensor();
94      }
95      std::unique_ptr<Tensor> &tensor = _tensors[tid];
96  
97      // Create connections
98      EdgeID eid        = _edges.size();
99      auto   connection = std::make_unique<Edge>(eid, source_node.get(), source_idx, sink_node.get(), sink_idx, tensor.get());
100      _edges.push_back(std::move(connection));
101  
102      // Add connections to source and sink nodes
103      source_node->_output_edges.insert(eid);
104      sink_node->_input_edges[sink_idx] = eid;
105  
106      // Set tensor output node
107      source_node->_outputs[source_idx] = tid;
108  
109      // Bind tensor to the edge
110      tensor->bind_edge(eid);
111  
112      // Try and propagate shapes in sink node
113      sink_node->forward_descriptors();
114  
115      return eid;
116  }
117  
remove_connection(EdgeID eid)118  bool Graph::remove_connection(EdgeID eid)
119  {
120      if(eid >= _edges.size())
121      {
122          return false;
123      }
124  
125      std::unique_ptr<Edge> &edge = _edges[eid];
126  
127      // Remove node connections
128      if(edge != nullptr)
129      {
130          // Get tensor bound to the edge
131          if(edge->tensor() != nullptr)
132          {
133              edge->tensor()->unbind_edge(eid);
134          }
135  
136          // Remove edges from source node
137          if(edge->producer() != nullptr)
138          {
139              edge->producer()->_output_edges.erase(eid);
140          }
141  
142          // Remove edges from sink node
143          if((edge->consumer() != nullptr) && (edge->consumer_idx() < edge->consumer()->_input_edges.size()))
144          {
145              edge->consumer()->_input_edges[edge->consumer_idx()] = EmptyEdgeID;
146          }
147      }
148  
149      // Clear edge
150      edge = nullptr;
151  
152      return true;
153  }
154  
create_tensor(const TensorDescriptor & desc)155  TensorID Graph::create_tensor(const TensorDescriptor &desc)
156  {
157      TensorID tid    = _tensors.size();
158      auto     tensor = std::make_unique<Tensor>(tid, desc);
159      _tensors.push_back(std::move(tensor));
160  
161      return tid;
162  }
163  
name() const164  std::string Graph::name() const
165  {
166      return _name;
167  }
168  
id() const169  GraphID Graph::id() const
170  {
171      return _id;
172  }
173  
nodes(NodeType type)174  const std::vector<NodeID> &Graph::nodes(NodeType type)
175  {
176      return _tagged_nodes[type];
177  }
178  
nodes()179  std::vector<std::unique_ptr<INode>> &Graph::nodes()
180  {
181      return _nodes;
182  }
183  
nodes() const184  const std::vector<std::unique_ptr<INode>> &Graph::nodes() const
185  {
186      return _nodes;
187  }
188  
edges() const189  const std::vector<std::unique_ptr<Edge>> &Graph::edges() const
190  {
191      return _edges;
192  }
193  
tensors()194  std::vector<std::unique_ptr<Tensor>> &Graph::tensors()
195  {
196      return _tensors;
197  }
198  
tensors() const199  const std::vector<std::unique_ptr<Tensor>> &Graph::tensors() const
200  {
201      return _tensors;
202  }
203  
node(NodeID id) const204  const INode *Graph::node(NodeID id) const
205  {
206      return (id >= _nodes.size()) ? nullptr : _nodes[id].get();
207  }
208  
node(NodeID id)209  INode *Graph::node(NodeID id)
210  {
211      return (id >= _nodes.size()) ? nullptr : _nodes[id].get();
212  }
213  
edge(EdgeID id) const214  const Edge *Graph::edge(EdgeID id) const
215  {
216      return (id >= _edges.size()) ? nullptr : _edges[id].get();
217  }
218  
edge(EdgeID id)219  Edge *Graph::edge(EdgeID id)
220  {
221      return (id >= _edges.size()) ? nullptr : _edges[id].get();
222  }
223  
tensor(TensorID id) const224  const Tensor *Graph::tensor(TensorID id) const
225  {
226      return (id >= _tensors.size()) ? nullptr : _tensors[id].get();
227  }
228  
tensor(TensorID id)229  Tensor *Graph::tensor(TensorID id)
230  {
231      return (id >= _tensors.size()) ? nullptr : _tensors[id].get();
232  }
233  } // namespace graph
234  } // namespace arm_compute