# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the SWIG wrapper tf_optimizer."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training as train


class MemoryOptimizerSwapTest(test.TestCase):
  """Tests the Grappler memory optimizer."""

  @test_util.run_deprecated_v1
  def testNoSwapping(self):
    """Make sure the graph is preserved when there is nothing to swap."""
    a = variables.VariableV1(10, name='a')
    b = variables.VariableV1(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    graph_size = len(mg.graph_def.node)
    nodes = [node.name for node in mg.graph_def.node]

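    # MANUAL memory optimization only acts on explicitly annotated nodes, so
    # with no annotations the optimized graph should be unchanged.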
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    graph = tf_optimizer.OptimizeGraph(config, mg)

    self.assertEqual(len(graph.node), graph_size)
    self.assertItemsEqual([node.name for node in graph.node], nodes)

  @test_util.run_v1_only('b/120545219')
  def testSimpleSwap(self):
    """Check that the swap annotations are followed."""
    with ops.device('/gpu:0'):
      a = variables.VariableV1(10, name='a')
      b = variables.VariableV1(20, name='b')
      c = math_ops.add_n([a, b], name='c')
      d = math_ops.add_n([b, c], name='d')
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(d)

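      # Annotate 'd' so the MANUAL memory optimizer swaps its first input out
      # to host memory and back in (see the swap_out/swap_in checks below).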
      d.op._set_attr('_swap_to_host', attr_value_pb2.AttrValue(i=0))

      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      graph_size = len(mg.graph_def.node)

      config = config_pb2.ConfigProto()
      config.graph_options.rewrite_options.CopyFrom(
          rewriter_config_pb2.RewriterConfig(
              disable_model_pruning=True,
              meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
              constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
              memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL,
              min_graph_nodes=-1))
      graph = tf_optimizer.OptimizeGraph(config, mg)

      self.assertEqual(len(graph.node), graph_size + 2)
      self.assertTrue(
          set(node.name for node in graph.node) > set(
              ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))
      for node in graph.node:
        if node.name == 'swap_in_d_0':
          self.assertEqual('swap_out_d_0', node.input[0])
          self.assertEqual('^b/read', node.input[1])
        elif node.name == 'swap_out_d_0':
          self.assertEqual('b/read', node.input[0])
        elif node.name == 'd':
          self.assertEqual('swap_in_d_0', node.input[0])
          self.assertEqual('c', node.input[1])


class MemoryOptimizerRecomputeTest(test.TestCase):
  """Tests the Python interface to recomputation rewrites.

  See core/grappler/optimizers/memory_optimizer_test.cc for functional tests.
  """

  def _GetMetaGraph(self, batch_size=14, image_dim=12, optimizer_scope_name=''):
    """A simple layered graph with conv, an intermediate op, and a ReLU."""
    graph = ops.Graph()
    with graph.as_default():
      random_seed.set_random_seed(1)
      current_activation = variable_scope.get_variable(
          name='start', shape=[batch_size, image_dim, image_dim, 5])
      conv_filter = variable_scope.get_variable(
          name='filter', shape=[5, 5, 5, 5])
      for layer_number in range(10):
        with variable_scope.variable_scope('layer_{}'.format(layer_number)):
          after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1],
                                 'SAME')
          current_activation = 2. * after_conv
          current_activation = nn.relu(current_activation)
      loss = math_ops.reduce_mean(current_activation)
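      # An optional name scope around the optimizer changes the prefix of the
      # generated gradient nodes, which the name-scope tests below exercise.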
      with ops.name_scope(optimizer_scope_name):
        optimizer = train.AdamOptimizer(0.001)
        train_op = optimizer.minimize(loss)
      init_op = variables.global_variables_initializer()
      metagraph = train.export_meta_graph()
    return (metagraph, init_op.name, train_op.name, loss.name)

  def testRewritingDefaultGradientNames(self):
    """Tests that rewriting occurs with default gradient names."""
    (original_metagraph, _, _, _) = self._GetMetaGraph()
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            min_graph_nodes=-1,
            memory_optimization=(
                rewriter_config_pb2.RewriterConfig.RECOMPUTATION_HEURISTICS)))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, original_metagraph)
    self.assertGreater(
        len(rewritten_graph_def.node),
        len(original_metagraph.graph_def.node))
    self.assertEqual(
        0,
        len([node for node in original_metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    self.assertEqual(
        20,  # Two per layer
        len([node for node in rewritten_graph_def.node
             if 'Recomputed/' in node.name]))

  def testRewritingNameScopedGradientNames(self):
    """Tests that rewriting occurs with non-standard gradient names."""
    (original_metagraph, _, _, _) = self._GetMetaGraph(
        optimizer_scope_name='optimizer')
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            min_graph_nodes=-1,
            memory_optimization=rewriter_config_pb2.RewriterConfig
            .RECOMPUTATION_HEURISTICS,
            # Checks that the name scope "gradients/" also matches sub-scopes.
            memory_optimizer_target_node_name_scope='gradients/'))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, original_metagraph)
    self.assertGreater(
        len(rewritten_graph_def.node),
        len(original_metagraph.graph_def.node))
    self.assertEqual(
        0,
        len([node for node in original_metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    self.assertEqual(
        20,  # Two per layer
        len([node for node in rewritten_graph_def.node
             if 'Recomputed/' in node.name]))

  def testRewritingNameScopedGradientNamesScope(self):
    """Tests that rewriting occurs with non-standard gradient names."""
    (original_metagraph, _, _,
     _) = self._GetMetaGraph(optimizer_scope_name='foo/bar')
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig
            .RECOMPUTATION_HEURISTICS,
            # This should not match anything.
            memory_optimizer_target_node_name_scope='r/gradients/'))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, original_metagraph)
    self.assertEqual(
        len(rewritten_graph_def.node), len(original_metagraph.graph_def.node))
    self.assertEqual(0,
                     len([
                         node for node in original_metagraph.graph_def.node
                         if 'Recomputed/' in node.name
                     ]))
    self.assertEqual(0,
                     len([
                         node for node in rewritten_graph_def.node
                         if 'Recomputed/' in node.name
                     ]))

  def _GetMemoryOptimizerSessionConfig(self):
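    """Returns a ConfigProto that enables HEURISTICS memory optimization."""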
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def _RunMetaGraphWithConfig(
      self, config, metagraph, init_op_name, train_op_name, loss_op_name):
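    """Imports metagraph, runs init and two train steps, returns the loss."""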
    graph = ops.Graph()
    with graph.as_default():
      train.import_meta_graph(metagraph)
      init_op = graph.get_operation_by_name(init_op_name)
      train_op = graph.get_operation_by_name(train_op_name)
      loss_op = graph.get_tensor_by_name(loss_op_name)
      with session.Session(config=config, graph=graph) as sess:
        self.evaluate(init_op)
        self.evaluate(train_op)
        self.evaluate(train_op)
        return self.evaluate(loss_op)

  def testRecomputationRewritingNoErrors(self):
    """Tests that graph output is not significantly different with rewriting."""
    (original_metagraph, init_op_name, train_op_name, loss_op_name
    ) = self._GetMetaGraph()
    original_loss = self._RunMetaGraphWithConfig(
        config=config_pb2.ConfigProto(),
        metagraph=original_metagraph,
        init_op_name=init_op_name,
        train_op_name=train_op_name,
        loss_op_name=loss_op_name)
    memory_optimized_loss = self._RunMetaGraphWithConfig(
        config=self._GetMemoryOptimizerSessionConfig(),
        metagraph=original_metagraph,
        init_op_name=init_op_name,
        train_op_name=train_op_name,
        loss_op_name=loss_op_name)
    self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-2)

  def _annotated_graph(self):
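    """Builds a small conv graph with manual _recompute_hint annotations."""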
    graph = ops.Graph()
    with graph.as_default():
      random_seed.set_random_seed(2)
      current_activation = variable_scope.get_variable(
          name='start', shape=[1, 2, 2, 5])
      conv_filter = variable_scope.get_variable(
          name='filter', shape=[5, 5, 5, 5])
      for layer_number in range(3):
        with variable_scope.variable_scope('layer_{}'.format(layer_number)):
          after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1],
                                 'SAME')
          current_activation = 2. * after_conv
          current_activation.op._set_attr(
              '_recompute_hint',
              # The value of the attribute does not matter; just that the key
              # exists in the op's attributes.
              attr_value_pb2.AttrValue(i=1))
          current_activation += 5.
          current_activation.op._set_attr(
              '_recompute_hint', attr_value_pb2.AttrValue(i=0))
          current_activation = nn.relu(current_activation)
          current_activation.op._set_attr(
              '_recompute_hint', attr_value_pb2.AttrValue(i=1))
      loss = math_ops.reduce_mean(current_activation)
      optimizer = train.AdamOptimizer(0.001)
      train_op = optimizer.minimize(loss)
      init_op = variables.global_variables_initializer()
    return graph, init_op, train_op

  def testHintNoMetaGraph(self):
    # Closer to expected usage, but does not check that a re-write actually
    # happens; see testHintDoesRewrite.
    graph, init_op, train_op = self._annotated_graph()
    with graph.as_default():
      manual_memory_config = rewriter_config_pb2.RewriterConfig(
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      graph_options = config_pb2.GraphOptions(
          rewrite_options=manual_memory_config)
      session_config = config_pb2.ConfigProto(graph_options=graph_options)
      with session.Session(config=session_config) as sess:
        self.evaluate(init_op)
        self.evaluate(train_op)

  @test_util.run_v1_only('b/120545219')
  def testHintDoesRewrite(self):
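    """Checks that _recompute_hint annotations produce Recomputed/ nodes."""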
    graph = self._annotated_graph()[0]
    with graph.as_default():
      metagraph = train.export_meta_graph()
    self.assertEqual(
        0,
        len([node for node in metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            min_graph_nodes=-1,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, metagraph)
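    # Three hinted ops in each of three layers yield nine recomputed nodes.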
    self.assertEqual(
        9,
        len([
            node for node in rewritten_graph_def.node
            if 'Recomputed/' in node.name
        ]))


if __name__ == '__main__':
  test.main()