xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h"
17 
18 #include "absl/strings/str_split.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/function.h"
21 #include "tensorflow/core/framework/metrics.h"
22 #include "tensorflow/core/framework/versions.pb.h"
23 #include "tensorflow/core/grappler/clusters/cluster.h"
24 #include "tensorflow/core/grappler/grappler_item.h"
25 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
26 #include "tensorflow/core/grappler/utils/functions.h"
27 #include "tensorflow/core/lib/gtl/map_util.h"
28 #include "tensorflow/core/util/ptr_util.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 
33 namespace {
34 
35 using ConfigMap =
36     std::map<string, tensorflow::RewriterConfig_CustomGraphOptimizer>;
37 
// tf.data optimizations, in the order we want to perform them.
//
// The array length is deliberately deduced from the initializer list. With an
// explicit `std::array<const char*, N>`, removing an entry without also
// decrementing `N` would silently leave a trailing nullptr element, which
// would then be converted to a `string` when handed to ApplyOptimization
// (undefined behavior). A deduced length can never disagree with the list.
constexpr const char* kTFDataOptimizations[] = {
    "noop_elimination",
    "disable_intra_op_parallelism",
    "use_private_thread_pool",
    "shuffle_and_repeat_fusion",
    "map_fusion",
    "filter_fusion",
    "map_and_filter_fusion",
    "map_parallelization",
    "map_and_batch_fusion",
    "batch_parallelization",
    "filter_parallelization",
    "make_sloppy",
    "parallel_batch",
    "slack",
    "autotune_buffer_sizes",
    "inject_prefetch",
    "disable_prefetch_legacy_autotune",
    "enable_gradient_descent",
    "make_deterministic"};
59 
60 // Parses a list of string optimizer configurations into a map from
61 // optimizer name -> rewriter config for that optimizer.
ToConfigMap(const tensorflow::RewriterConfig_CustomGraphOptimizer * config,ConfigMap * result)62 Status ToConfigMap(
63     const tensorflow::RewriterConfig_CustomGraphOptimizer* config,
64     ConfigMap* result) {
65   auto found = gtl::FindOrNull(config->parameter_map(), "optimizer_configs");
66   if (!found) return OkStatus();
67 
68   auto& options = found->list().s();
69   for (const auto& option_string : options) {
70     // The option string has the format
71     // <optimizer_name>:<config_key>:<config_value>
72     std::vector<string> split = absl::StrSplit(option_string, ':');
73     if (split.size() != 3) {
74       return errors::Internal(
75           "Wrong format for optimizer options. Expect <optimizer name>:<config "
76           "key>:<config value>, received: ",
77           option_string);
78     }
79 
80     const string& optimizer_name = split[0];
81     const string& config_key = split[1];
82     const string& config_value = split[2];
83 
84     auto optimizer_config = gtl::FindOrNull(*result, optimizer_name);
85     if (!optimizer_config) {
86       (*result)[optimizer_name] =
87           tensorflow::RewriterConfig_CustomGraphOptimizer();
88       optimizer_config = gtl::FindOrNull(*result, optimizer_name);
89     }
90     (*optimizer_config->mutable_parameter_map())[config_key].set_s(
91         config_value);
92   }
93 
94   return OkStatus();
95 }
96 
97 }  // namespace
98 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * output)99 Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
100                                      GraphDef* output) {
101   // Stores the optimized item so far.
102   GrapplerItem optimized_item = item;
103 
104   // Perform optimizations in a meaningful order.
105   for (const auto& optimization : kTFDataOptimizations) {
106     tensorflow::metrics::ScopedCounter<2> timings(
107         tensorflow::metrics::GetGraphOptimizationCounter(),
108         {"TFData", optimization});
109     Status status = ApplyOptimization(optimization, cluster, &optimized_item);
110     timings.ReportAndStop();
111     if (!status.ok()) return status;
112   }
113 
114   // Store the final result of all the optimizations in `output`.
115   output->Swap(&optimized_item.graph);
116 
117   // Optimize tf.data user-defined functions.
118   FunctionLibraryDefinition flib =
119       FunctionLibraryDefinition(OpRegistry::Global(), output->library())
120           .ReachableDefinitions(*output);
121   const auto producer = output->versions().producer();
122   bool optimized_functions = false;
123   for (const auto& name : flib.ListFunctionNames()) {
124     auto* func = flib.Find(name);
125     // Skip non tf.data functions.
126     if (!data::IsTFDataFunction(*func)) continue;
127     VLOG(3) << "Optimize function: function=" << func->signature().name();
128     optimized_functions = true;
129 
130     // Make a GrapplerItem from a FunctionDef.
131     GrapplerFunctionItem func_item;
132     TF_RETURN_IF_ERROR(
133         MakeGrapplerFunctionItem(*func, flib, producer, &func_item));
134 
135     GraphDef optimized_func_graph;
136     TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
137 
138     // Function body optimization might have created new functions. Add them to
139     // the library.
140     for (const FunctionDef& func_def :
141          optimized_func_graph.library().function()) {
142       if (flib.Find(func_def.signature().name()) == nullptr) {
143         TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
144       }
145     }
146 
147     // Convert optimized graph back to FunctionDef.
148     FunctionDef optimized_func;
149     func_item.SwapFunctionBody(std::move(optimized_func_graph));
150     TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
151 
152     // Replace optimized function with a new FunctionDef.
153     TF_RETURN_IF_ERROR(
154         flib.ReplaceFunction(func->signature().name(), optimized_func));
155   }
156   if (optimized_functions) {
157     *output->mutable_library() = flib.ToProto();
158   }
159   return OkStatus();
160 }
161 
ApplyOptimization(const string & name,Cluster * cluster,GrapplerItem * item) const162 Status TFDataMetaOptimizer::ApplyOptimization(const string& name,
163                                               Cluster* cluster,
164                                               GrapplerItem* item) const {
165   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
166 
167   const auto* optimizer = gtl::FindOrNull(enabled_optimizers_, name);
168   if (!optimizer) {
169     return OkStatus();
170   }
171 
172   GraphDef result;
173   (*optimizer)->set_deadline_usec(this->deadline_usec());
174   Status status = (*optimizer)->Optimize(cluster, *item, &result);
175   if (status.ok()) {
176     // The optimizer succeeded and wrote the optimized graph to result.
177     item->graph.Swap(&result);
178   } else if (errors::IsAborted(status)) {
179     // A status of errors::Aborted just means that the optimizer was a no-op and
180     // did not populate result. Swallow the error status and leave the original
181     // graph in item.
182     status = OkStatus();
183   }
184 
185   return status;
186 }
187 
Init(const tensorflow::RewriterConfig_CustomGraphOptimizer * config)188 Status TFDataMetaOptimizer::Init(
189     const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
190   if (!config) return OkStatus();
191 
192   // Initialize custom tf.data optimizers based on config.
193   auto& optimizers = config->parameter_map().at("optimizers").list().s();
194   ConfigMap optimizer_configs;
195   TF_RETURN_IF_ERROR(ToConfigMap(config, &optimizer_configs));
196 
197   for (const auto& optimizer_name : optimizers) {
198     auto optimizer =
199         CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
200     if (optimizer) {
201       TF_RETURN_IF_ERROR(
202           optimizer->Init(gtl::FindOrNull(optimizer_configs, optimizer_name)));
203 
204       enabled_optimizers_[optimizer_name] = std::move(optimizer);
205     } else {
206       return errors::Internal(
207           "Tried to register a dataset optimizer that doesn't exist: ",
208           optimizer_name);
209     }
210   }
211 
212   return OkStatus();
213 }
214 
215 REGISTER_GRAPH_OPTIMIZER_AS(TFDataMetaOptimizer, "tf_data_meta_optimizer");
216 
217 }  // namespace grappler
218 }  // namespace tensorflow
219