xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/tests/auto_clustering_test_helper.h"
17 
18 #include "absl/strings/numbers.h"
19 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
20 #include "tensorflow/compiler/jit/xla_cluster_util.h"
21 #include "tensorflow/compiler/xla/status_macros.h"
22 #include "tensorflow/compiler/xla/statusor.h"
23 #include "tensorflow/core/common_runtime/graph_constructor.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/lib/io/random_inputstream.h"
26 #include "tensorflow/core/lib/io/zlib_compression_options.h"
27 #include "tensorflow/core/lib/io/zlib_inputstream.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/platform/test.h"
30 #include "tensorflow/core/platform/test_benchmark.h"
31 #include "tensorflow/core/util/port.h"
32 #include "tensorflow/tools/optimization/optimization_pass_runner.h"
33 
34 namespace tensorflow {
35 namespace {
SummarizeClustering(const GraphDef & auto_clustered_graph_def)36 StatusOr<string> SummarizeClustering(const GraphDef& auto_clustered_graph_def) {
37   testing::ResetClusterSequenceNumber();
38   Graph graph(OpRegistry::Global());
39   GraphConstructorOptions graph_opts;
40   graph_opts.expect_device_spec = true;
41   graph_opts.allow_internal_ops = true;
42   TF_RETURN_IF_ERROR(
43       ConvertGraphDefToGraph(graph_opts, auto_clustered_graph_def, &graph));
44 
45   // cluster_id -> (operation name -> # of operations)
46   const int kNoCluster = -1;
47   std::map<int, std::map<string, int>> clusters;
48   std::map<int, int> cluster_size;
49   int clustered_nodes = 0;
50   for (Node* n : graph.op_nodes()) {
51     int cluster = kNoCluster;
52     if (std::optional<absl::string_view> maybe_cluster =
53             GetXlaClusterForNode(*n)) {
54       maybe_cluster->remove_prefix(absl::string_view("cluster_").size());
55       TF_RET_CHECK(absl::SimpleAtoi(*maybe_cluster, &cluster));
56       clustered_nodes++;
57     }
58     clusters[cluster][n->type_string()]++;
59     cluster_size[cluster]++;
60   }
61 
62   string result =
63       absl::StrCat("Clustered nodes: ", clustered_nodes,
64                    "\nUnclustered nodes: ", cluster_size[kNoCluster],
65                    "\nNumber of clusters: ", clusters.size() - 1, "\n\n");
66   for (const auto& pair : clusters) {
67     if (pair.first == kNoCluster) {
68       absl::StrAppend(&result, "unclustered");
69     } else {
70       absl::StrAppend(&result, "cluster ", pair.first);
71     }
72 
73     absl::StrAppend(&result, " size ", cluster_size[pair.first], "\n");
74 
75     for (const auto& ops_and_counts : pair.second) {
76       absl::StrAppend(&result, " ", ops_and_counts.first, " ",
77                       ops_and_counts.second, "\n");
78     }
79   }
80 
81   return result;
82 }
83 
AssertGraphDefIsUnclustered(const GraphDef & graphdef)84 Status AssertGraphDefIsUnclustered(const GraphDef& graphdef) {
85   const char* kXlaClusterAttr = "_XlaCluster";
86   const char* kXlaAlreadyClusteredAttr = "_XlaAlreadyClustered";
87 
88   for (const NodeDef& node : graphdef.node()) {
89     if (node.attr().count(kXlaClusterAttr) ||
90         node.attr().count(kXlaAlreadyClusteredAttr)) {
91       return errors::InvalidArgument(
92           "Input files are already clustered, you probably copied in "
93           "mark_for_compilation_<n>.pbtxt when you should have copied in "
94           "before_mark_for_compilation_<n>.pbtxt");
95     }
96   }
97 
98   return OkStatus();
99 }
100 
ReadTextProtoFromString(Env * env,const string & data,::tensorflow::protobuf::Message * proto)101 Status ReadTextProtoFromString(Env* env, const string& data,
102                                ::tensorflow::protobuf::Message* proto) {
103   if (!::tensorflow::protobuf::TextFormat::ParseFromString(data, proto)) {
104     return errors::DataLoss("Can't parse input data as text proto");
105   }
106   return OkStatus();
107 }
108 }  // namespace
109 
RunAutoClusteringTestImpl(GraphDef graphdef,absl::string_view golden_summary_file_path)110 Status AutoClusteringTest::RunAutoClusteringTestImpl(
111     GraphDef graphdef, absl::string_view golden_summary_file_path) {
112   if (!IsGoogleCudaEnabled()) {
113     // There is some slight change in the clustering decisions under
114     // --config=cuda.  I have not looked closely at why that is happening, but
115     // most likely some of the partial declustering passes behave differently
116     // with --config=cuda because of different HostMemory.  So for now only test
117     // the non-CUDA config, under the assumption that regressions with
118     // --config=cuda would also be detected as regressions without
119     // --config=cuda.
120 
121     LOG(INFO) << "Not running "
122               << ::testing::UnitTest::GetInstance()->current_test_info()->name()
123               << " since test was not built with --config=cuda";
124     return OkStatus();
125   }
126 
127   TF_RETURN_IF_ERROR(AssertGraphDefIsUnclustered(graphdef));
128 
129   OptimizationPassRunner runner;
130   TF_RETURN_IF_ERROR(runner.SetJitLevel(tensorflow::OptimizerOptions::ON_2));
131   TF_RETURN_IF_ERROR(runner.AddCpus(32));
132   TF_RETURN_IF_ERROR(runner.AddGpus(8));
133 
134   for (absl::string_view auto_clustering_pass :
135        {"CloneConstantsForBetterClusteringPass", "MarkForCompilationPass",
136         "IncreaseDynamismForAutoJitPass", "PartiallyDeclusterPass"}) {
137     GraphDef next;
138     TF_RETURN_IF_ERROR(
139         runner.Run(auto_clustering_pass, std::move(graphdef), &next));
140     graphdef = std::move(next);
141   }
142 
143   TF_ASSIGN_OR_RETURN(string clustering_summary, SummarizeClustering(graphdef));
144 
145   // To update golden files flip this to true and run
146   //
147   // bazel test --test_strategy=local \
148   //   tensorflow/compiler/jit/tests:auto_clustering_test
149   bool update_golden = false;
150   if (update_golden) {
151     TF_RETURN_IF_ERROR(WriteStringToFile(
152         Env::Default(), string(golden_summary_file_path), clustering_summary));
153   }
154 
155   string golden_file_contents;
156   TF_RETURN_IF_ERROR(ReadFileToString(
157       Env::Default(), string(golden_summary_file_path), &golden_file_contents));
158 
159   EXPECT_EQ(golden_file_contents, clustering_summary);
160 
161   return OkStatus();
162 }
163 
RunAutoClusteringTestWithPbtxt(absl::string_view pbtxt_file_path,absl::string_view golden_summary_file_path)164 Status AutoClusteringTest::RunAutoClusteringTestWithPbtxt(
165     absl::string_view pbtxt_file_path,
166     absl::string_view golden_summary_file_path) {
167   GraphDef graphdef;
168   TF_RETURN_IF_ERROR(
169       ReadTextProto(Env::Default(), string(pbtxt_file_path), &graphdef));
170   return RunAutoClusteringTestImpl(std::move(graphdef),
171                                    golden_summary_file_path);
172 }
173 
RunAutoClusteringTestWithGzippedPbtxt(absl::string_view gzipped_pbtxt_file_path,absl::string_view golden_summary_file_path)174 Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt(
175     absl::string_view gzipped_pbtxt_file_path,
176     absl::string_view golden_summary_file_path) {
177   Env* env = Env::Default();
178   std::unique_ptr<RandomAccessFile> file_reader;
179   TF_RETURN_IF_ERROR(
180       env->NewRandomAccessFile(string(gzipped_pbtxt_file_path), &file_reader));
181   std::unique_ptr<io::RandomAccessInputStream> input_stream(
182       new io::RandomAccessInputStream(file_reader.get()));
183   constexpr int k_buffer_size = 256 << 10;  // 256kb
184   io::ZlibInputStream in(input_stream.get(),
185                          /*input_buffer_bytes=*/k_buffer_size,
186                          /*output_buffer_bytes=*/k_buffer_size,
187                          io::ZlibCompressionOptions::GZIP());
188   tstring decompressed_pbtxt_string;
189   Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string);
190   if (!s.ok() && !errors::IsOutOfRange(s)) {
191     // OutOfRange is fine since we set the number of read bytes to INT_MAX.
192     // Only return other kinds of errors.
193     return s;
194   }
195 
196   GraphDef graphdef;
197   TF_RETURN_IF_ERROR(ReadTextProtoFromString(
198       Env::Default(), decompressed_pbtxt_string, &graphdef));
199   return RunAutoClusteringTestImpl(std::move(graphdef),
200                                    golden_summary_file_path);
201 }
202 
203 #if defined(PLATFORM_GOOGLE)
BenchmarkMarkForCompilation(absl::string_view graph_def_path,benchmark::State & state)204 Status BenchmarkMarkForCompilation(absl::string_view graph_def_path,
205                                    benchmark::State& state) {
206   GraphDef graph_def;
207   TF_RETURN_IF_ERROR(
208       ReadTextProto(Env::Default(), string(graph_def_path), &graph_def));
209 
210   OptimizationPassRunner runner;
211   TF_RETURN_IF_ERROR(runner.SetJitLevel(tensorflow::OptimizerOptions::ON_2));
212   TF_RETURN_IF_ERROR(runner.AddCpus(32));
213   TF_RETURN_IF_ERROR(runner.AddGpus(8));
214 
215   for (auto _ : state) {
216     state.PauseTiming();
217     GraphDef result;
218     GraphDef graph_def_copy = graph_def;
219     state.ResumeTiming();
220     TF_RETURN_IF_ERROR(runner.Run("MarkForCompilationPass",
221                                   std::move(graph_def_copy), &result));
222   }
223 
224   return OkStatus();
225 }
226 #endif  // PLATFORM_GOOGLE
227 
228 }  // namespace tensorflow
229