/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h"

#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace grappler {

namespace {

using ConfigMap =
    std::map<string, tensorflow::RewriterConfig_CustomGraphOptimizer>;

// tf.data optimizations, in the order we want to perform them.
constexpr std::array<const char*, 19> kTFDataOptimizations = {
    "noop_elimination",
    "disable_intra_op_parallelism",
    "use_private_thread_pool",
    "shuffle_and_repeat_fusion",
    "map_fusion",
    "filter_fusion",
    "map_and_filter_fusion",
    "map_parallelization",
    "map_and_batch_fusion",
    "batch_parallelization",
    "filter_parallelization",
    "make_sloppy",
    "parallel_batch",
    "slack",
    "autotune_buffer_sizes",
    "inject_prefetch",
    "disable_prefetch_legacy_autotune",
    "enable_gradient_descent",
    "make_deterministic"};

// Parses a list of string optimizer configurations into a map from
// optimizer name -> rewriter config for that optimizer.
Status ToConfigMap(
    const tensorflow::RewriterConfig_CustomGraphOptimizer* config,
    ConfigMap* result) {
  auto found = gtl::FindOrNull(config->parameter_map(), "optimizer_configs");
  if (!found) return OkStatus();

  auto& options = found->list().s();
  for (const auto& option_string : options) {
    // The option string has the format
    // <optimizer_name>:<config_key>:<config_value>
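    // For example, "map_fusion:mode:aggressive" would set the (hypothetical)
    // config key "mode" to "aggressive" for the "map_fusion" optimizer.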
    std::vector<string> split = absl::StrSplit(option_string, ':');
    if (split.size() != 3) {
      return errors::Internal(
          "Wrong format for optimizer options. Expect <optimizer name>:<config "
          "key>:<config value>, received: ",
          option_string);
    }

    const string& optimizer_name = split[0];
    const string& config_key = split[1];
    const string& config_value = split[2];

    auto optimizer_config = gtl::FindOrNull(*result, optimizer_name);
    if (!optimizer_config) {
      (*result)[optimizer_name] =
          tensorflow::RewriterConfig_CustomGraphOptimizer();
      optimizer_config = gtl::FindOrNull(*result, optimizer_name);
    }
    (*optimizer_config->mutable_parameter_map())[config_key].set_s(
        config_value);
  }

  return OkStatus();
}

}  // namespace

Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                     GraphDef* output) {
  // Stores the optimized item so far.
  GrapplerItem optimized_item = item;

  // Perform optimizations in a meaningful order.
  for (const auto& optimization : kTFDataOptimizations) {
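    // Record the wall-clock time spent in each pass under a
    // {"TFData", <pass name>} label in the graph-optimization metrics.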
    tensorflow::metrics::ScopedCounter<2> timings(
        tensorflow::metrics::GetGraphOptimizationCounter(),
        {"TFData", optimization});
    Status status = ApplyOptimization(optimization, cluster, &optimized_item);
    timings.ReportAndStop();
    if (!status.ok()) return status;
  }

  // Store the final result of all the optimizations in `output`.
  output->Swap(&optimized_item.graph);

  // Optimize tf.data user-defined functions.
  FunctionLibraryDefinition flib =
      FunctionLibraryDefinition(OpRegistry::Global(), output->library())
          .ReachableDefinitions(*output);
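  // The graph's producer version is needed below to instantiate the function
  // bodies as graphs.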
  const auto producer = output->versions().producer();
  bool optimized_functions = false;
  for (const auto& name : flib.ListFunctionNames()) {
    auto* func = flib.Find(name);
    // Skip non-tf.data functions.
    if (!data::IsTFDataFunction(*func)) continue;
    VLOG(3) << "Optimize function: function=" << func->signature().name();
    optimized_functions = true;

    // Make a GrapplerItem from a FunctionDef.
    GrapplerFunctionItem func_item;
    TF_RETURN_IF_ERROR(
        MakeGrapplerFunctionItem(*func, flib, producer, &func_item));

    GraphDef optimized_func_graph;
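    // Run the function body through this same meta-optimizer; the call is
    // recursive, so nested tf.data functions are optimized as well.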
    TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));

    // Function body optimization might have created new functions. Add them
    // to the library.
    for (const FunctionDef& func_def :
         optimized_func_graph.library().function()) {
      if (flib.Find(func_def.signature().name()) == nullptr) {
        TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
      }
    }

    // Convert the optimized graph back to a FunctionDef.
    FunctionDef optimized_func;
    func_item.SwapFunctionBody(std::move(optimized_func_graph));
    TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));

    // Replace the original function with the optimized one.
    TF_RETURN_IF_ERROR(
        flib.ReplaceFunction(func->signature().name(), optimized_func));
  }
  if (optimized_functions) {
    *output->mutable_library() = flib.ToProto();
  }
  return OkStatus();
}

Status TFDataMetaOptimizer::ApplyOptimization(const string& name,
                                              Cluster* cluster,
                                              GrapplerItem* item) const {
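  // Bail out early if the meta-optimizer's deadline has already expired.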
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

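  // Optimizers that were not enabled in Init() are silently skipped.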
  const auto* optimizer = gtl::FindOrNull(enabled_optimizers_, name);
  if (!optimizer) {
    return OkStatus();
  }

  GraphDef result;
  (*optimizer)->set_deadline_usec(this->deadline_usec());
  Status status = (*optimizer)->Optimize(cluster, *item, &result);
  if (status.ok()) {
    // The optimizer succeeded and wrote the optimized graph to result.
    item->graph.Swap(&result);
  } else if (errors::IsAborted(status)) {
    // A status of errors::Aborted just means that the optimizer was a no-op
    // and did not populate result. Swallow the error status and leave the
    // original graph in item.
    status = OkStatus();
  }

  return status;
}

Status TFDataMetaOptimizer::Init(
    const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
  if (!config) return OkStatus();

  // Initialize custom tf.data optimizers based on config.
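  // Note: `at` assumes the "optimizers" key is always present in the
  // parameter map; the tf.data rewrite machinery that constructs this config
  // is expected to set it.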
  auto& optimizers = config->parameter_map().at("optimizers").list().s();
  ConfigMap optimizer_configs;
  TF_RETURN_IF_ERROR(ToConfigMap(config, &optimizer_configs));

  for (const auto& optimizer_name : optimizers) {
    auto optimizer =
        CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
    if (optimizer) {
      TF_RETURN_IF_ERROR(
          optimizer->Init(gtl::FindOrNull(optimizer_configs, optimizer_name)));

      enabled_optimizers_[optimizer_name] = std::move(optimizer);
    } else {
      return errors::Internal(
          "Tried to register a dataset optimizer that doesn't exist: ",
          optimizer_name);
    }
  }

  return OkStatus();
}

REGISTER_GRAPH_OPTIMIZER_AS(TFDataMetaOptimizer, "tf_data_meta_optimizer");
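
// A RewriterConfig that enables this optimizer looks roughly like the
// following (a sketch; the "mode:aggressive" config entry is illustrative):
//
//   custom_optimizers {
//     name: "tf_data_meta_optimizer"
//     parameter_map {
//       key: "optimizers"
//       value { list { s: "noop_elimination" s: "map_fusion" } }
//     }
//     parameter_map {
//       key: "optimizer_configs"
//       value { list { s: "map_fusion:mode:aggressive" } }
//     }
//   }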

}  // namespace grappler
}  // namespace tensorflow