1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
18
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/function.pb.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/grappler/grappler_item.h"
28 #include "tensorflow/core/grappler/mutable_graph_view.h"
29 #include "tensorflow/core/grappler/utils.h"
30 #include "tensorflow/core/lib/core/errors.h"
31
32 namespace tensorflow {
33 namespace grappler {
34 namespace graph_utils {
35
// Scans `collection` in iteration order and reports the position of the first
// element for which `predicate` returns true. Returns -1 when no element
// matches, which includes the case of an empty `collection`.
//
// `collection` may be anything usable in a range-based for loop (standard
// containers, protobuf repeated fields, C arrays); `predicate` must be
// callable with one element of it.
template <typename Predicate, typename Collection>
int GetFirstElementIndexWithPredicate(const Predicate& predicate,
                                      const Collection& collection) {
  unsigned position = 0;
  for (auto&& item : collection) {
    if (!predicate(item)) {
      // Not a match; move on to the next slot.
      ++position;
      continue;
    }
    return position;
  }
  return -1;
}
50
// Adds a node to the graph.
//
// `name` and `op` are presumably the new node's name and op type, `inputs`
// lists its input tensor names, and `attributes` holds (attribute name, value)
// pairs to set on it. Returns a NodeDef* — presumably the newly added node,
// owned by `graph` (confirm against the implementation).
NodeDef* AddNode(StringPiece name, StringPiece op,
                 const std::vector<string>& inputs,
                 const std::vector<std::pair<string, AttrValue>>& attributes,
                 MutableGraphView* graph);

// Adds a Placeholder node for the given data type `dtype` to `graph`.
// NOTE(review): the name suggests the placeholder is scalar-shaped and that
// the returned pointer is the new node — confirm against the implementation.
NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);
59
60 // Adds a Const node with the given value to the graph.
61 template <typename T>
AddScalarConstNode(T v,MutableGraphView * graph)62 NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
63 // is_same is an idiomatic hack for making it compile if not instantiated.
64 // Replacing with false will result in a compile-time error.
65 static_assert(!std::is_same<T, T>::value,
66 "Invalid specialization of this method for type T.");
67 return {};
68 }
69
// Explicit specializations of AddScalarConstNode for the supported scalar
// types. They are only declared in this header; their definitions live in
// another translation unit (presumably graph_utils.cc — confirm).
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(double v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(float v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(int v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(int64_t v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph);
82
83 // Retrieves the value of a const node. Returns an error
84 // if the node is not const, or its value is of a different type.
85 template <typename T>
GetScalarConstNodeValue(const NodeDef & node,T * value)86 Status GetScalarConstNodeValue(const NodeDef& node, T* value) {
87 // is_same is an idiomatic hack for making it compile if not instantiated.
88 // Replacing with false will result in a compile-time error.
89 static_assert(!std::is_same<T, T>::value,
90 "Invalid specialization of this method fo rtype T.");
91 }
92
// Explicit specializations of GetScalarConstNodeValue for the supported value
// types. They are only declared in this header; their definitions live in
// another translation unit (presumably graph_utils.cc — confirm).
template <>
Status GetScalarConstNodeValue(const NodeDef& node, int64_t* value);
template <>
Status GetScalarConstNodeValue(const NodeDef& node, bool* value);
97
// Checks whether the two graphs are the same.
// NOTE(review): the precise notion of "same" (e.g. whether node order
// matters) is defined by the implementation — confirm there.
bool Compare(const GraphDef& g1, const GraphDef& g2);

// Checks whether the graph contains a node with the given name.
bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph);

// Checks whether the library contains a function with the given name.
bool ContainsGraphFunctionWithName(StringPiece name,
                                   const FunctionDefLibrary& library);

// Checks whether the graph contains a node with the given op.
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);

// Returns the index of the node with the given name, or -1 if the node does
// not exist. The index is presumably the node's position in `graph`'s node
// list (confirm).
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph);

// Returns the index of the function with the given name, or -1 if the
// function does not exist.
int FindGraphFunctionWithName(StringPiece name,
                              const FunctionDefLibrary& library);

// Returns the index of the first node with the given op, or -1 if no such
// node exists.
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
123
// Gets the 0th input to a node in the graph, resolved through `graph`.
// NOTE(review): the result when `node` has no inputs is not visible here —
// confirm whether nullptr is returned.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);

// Gets the ith input to a node in the graph, resolved through `graph`.
// NOTE(review): behavior for an out-of-range `i` is not visible here —
// confirm against the implementation.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
                      int64_t i);

// Gets the attr corresponding to a dataset node's output types, if it exists.
// The types are presumably written to `*output_types`, with a non-OK Status
// returned when no such attr is present — confirm.
Status GetDatasetOutputTypesAttr(const NodeDef& node,
                                 DataTypeVector* output_types);

// Returns the list of indices of all nodes with the given op, or an empty
// list if no such node exists.
std::vector<int> FindAllGraphNodesWithOp(const string& op,
                                         const GraphDef& graph);
139
// Sets the name of `node` using `prefix` as a prefix, while guaranteeing the
// name is unique across `graph`.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);

// Sets the name of `function` using `prefix` as a prefix, while guaranteeing
// the name is unique across the function library. `library` is passed as a
// pointer to const, so it is not modified through this call; in particular
// `function` is not added to it here.
void SetUniqueGraphFunctionName(StringPiece prefix,
                                const FunctionDefLibrary* library,
                                FunctionDef* function);

// Copies the attribute named `attribute_name` from node `from` to node
// `to_node`.
// NOTE(review): behavior when `from` lacks the attribute is not visible
// here — confirm against the implementation.
void CopyAttribute(const string& attribute_name, const NodeDef& from,
                   NodeDef* to_node);

// Concatenates the list attribute named `attribute_name` of the `first` and
// `second` nodes (presumably in that order), setting the result on `to_node`.
void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
                         const NodeDef& second, NodeDef* to_node);
159
// Checks that all nodes in the graph have unique names, and renames them so
// that they are unique if they are not already. This is necessary because
// Graph does not have the provisions to deduplicate names, and name
// deduplication elsewhere in tensorflow happens in other layers (for example,
// in the Scope class of the C++ API). Note that the nodes in the graph are
// identified by their id, so renaming nodes does not mutate any edges.
Status EnsureNodeNamesUnique(Graph* g);

// Stores the item's fetch node in `*fetch_node` if there is exactly one.
// Otherwise, returns an error.
Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item,
                    NodeDef** fetch_node);

// Returns true if `item` is derived from a `FunctionDef`, false otherwise.
// Currently, we determine this heuristically: if we don't have any fetch nodes
// or all fetch nodes are `Retval` ops, then we consider this item as derived
// from a `FunctionDef`.
bool IsItemDerivedFromFunctionDef(const GrapplerItem& item,
                                  const MutableGraphView& graph_view);
179
// If both input nodes `node1` and `node2` have the "metadata" attribute set,
// populates the "metadata" attribute of `fused_node`. Otherwise `fused_node`
// is presumably left unchanged (confirm).
void MaybeSetFusedMetadata(const NodeDef& node1, const NodeDef& node2,
                           NodeDef* fused_node);

// Copies the attributes `output_shapes` and `output_types` from node `from`
// to node `to_node` if they exist. Returns `true` if the attributes were
// copied successfully, and `false` otherwise.
//
// Some tf.data transformations set `Toutput_types` instead of `output_types`
// when the attribute describes the types of tensor inputs (e.g. TensorDataset,
// TensorSliceDataset, and PaddedBatchDataset). In that case this function
// copies the attribute `Toutput_types` of node `from` to the attribute
// `output_types` of node `to_node`.
bool CopyShapesAndTypesAttrs(const NodeDef& from, NodeDef* to_node);

// Checks whether the op named `op` has a "sloppy" attribute.
bool HasSloppyAttr(const string& op);

// Checks whether the op named `op` has a "replicate_on_split" attribute.
bool HasReplicateOnSplitAttr(const string& op);

// Checks whether the op named `op` has a "deterministic" attribute.
bool HasDeterministicAttr(const string& op);

// Sets `name` as the metadata name of `node`. Returns an error if `node`
// already has a metadata name.
Status SetMetadataName(const std::string& name, NodeDef* node);
208
209 } // namespace graph_utils
210 } // namespace grappler
211 } // namespace tensorflow
212
213 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
214