# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper tf_optimizer."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training as train


class MemoryOptimizerSwapTest(test.TestCase):
  """Tests the Grappler memory optimizer."""

  @test_util.run_deprecated_v1
  def testNoSwapping(self):
    """Make sure the graph is preserved when there is nothing to swap."""
    a = variables.VariableV1(10, name='a')
    b = variables.VariableV1(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    graph_size = len(mg.graph_def.node)
    nodes = [node.name for node in mg.graph_def.node]

    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    graph = tf_optimizer.OptimizeGraph(config, mg)

    self.assertEqual(len(graph.node), graph_size)
    self.assertItemsEqual([node.name for node in graph.node], nodes)

  @test_util.run_v1_only('b/120545219')
  def testSimpleSwap(self):
    """Check that the swap annotations are followed."""
    with ops.device('/gpu:0'):
      a = variables.VariableV1(10, name='a')
      b = variables.VariableV1(20, name='b')
      c = math_ops.add_n([a, b], name='c')
      d = math_ops.add_n([b, c], name='d')
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(d)

      d.op._set_attr('_swap_to_host', attr_value_pb2.AttrValue(i=0))

      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      graph_size = len(mg.graph_def.node)

      config = config_pb2.ConfigProto()
      config.graph_options.rewrite_options.CopyFrom(
          rewriter_config_pb2.RewriterConfig(
              disable_model_pruning=True,
              meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
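              # In MANUAL mode (set below) the memory optimizer only rewrites
              # ops that carry explicit annotations, such as the
              # '_swap_to_host' hint set on 'd' above.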
              constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
              memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL,
              min_graph_nodes=-1))
      graph = tf_optimizer.OptimizeGraph(config, mg)

      self.assertEqual(len(graph.node), graph_size + 2)
      self.assertTrue(
          set(node.name for node in graph.node) > set(
              ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))
      for node in graph.node:
        if node.name == 'swap_in_d_0':
          self.assertEqual('swap_out_d_0', node.input[0])
          self.assertEqual('^b/read', node.input[1])
        elif node.name == 'swap_out_d_0':
          self.assertEqual('b/read', node.input[0])
        elif node.name == 'd':
          self.assertEqual('swap_in_d_0', node.input[0])
          self.assertEqual('c', node.input[1])


class MemoryOptimizerRecomputeTest(test.TestCase):
  """Tests the Python interface to recomputation rewrites.

  See core/grappler/optimizers/memory_optimizer_test.cc for functional tests.
  """

  def _GetMetaGraph(self, batch_size=14, image_dim=12,
                    optimizer_scope_name=''):
    """A simple layered graph with conv, an intermediate op, and a ReLU."""
    graph = ops.Graph()
    with graph.as_default():
      random_seed.set_random_seed(1)
      current_activation = variable_scope.get_variable(
          name='start', shape=[batch_size, image_dim, image_dim, 5])
      conv_filter = variable_scope.get_variable(
          name='filter', shape=[5, 5, 5, 5])
      for layer_number in range(10):
        with variable_scope.variable_scope('layer_{}'.format(layer_number)):
          after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1],
                                 'SAME')
          current_activation = 2. * after_conv
          current_activation = nn.relu(current_activation)
      loss = math_ops.reduce_mean(current_activation)
      with ops.name_scope(optimizer_scope_name):
        optimizer = train.AdamOptimizer(0.001)
        train_op = optimizer.minimize(loss)
      init_op = variables.global_variables_initializer()
      metagraph = train.export_meta_graph()
    return (metagraph, init_op.name, train_op.name, loss.name)

  def testRewritingDefaultGradientNames(self):
    """Tests that rewriting occurs with default gradient names."""
    (original_metagraph, _, _, _) = self._GetMetaGraph()
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            min_graph_nodes=-1,
            memory_optimization=(
                rewriter_config_pb2.RewriterConfig.RECOMPUTATION_HEURISTICS)))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, original_metagraph)
    self.assertGreater(
        len(rewritten_graph_def.node),
        len(original_metagraph.graph_def.node))
    self.assertEqual(
        0,
        len([node for node in original_metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    self.assertEqual(
        20,  # Two per layer
        len([node for node in rewritten_graph_def.node
             if 'Recomputed/' in node.name]))

  def testRewritingNameScopedGradientNames(self):
    """Tests that rewriting occurs with non-standard gradient names."""
    (original_metagraph, _, _, _) = self._GetMetaGraph(
        optimizer_scope_name='optimizer')
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
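            # The optimizer was built under the name scope 'optimizer', so the
            # gradient nodes to be recomputed live under
            # 'optimizer/gradients/'.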
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            min_graph_nodes=-1,
            memory_optimization=rewriter_config_pb2.RewriterConfig
            .RECOMPUTATION_HEURISTICS,
            # Checks that the name scope "gradients/" also matches sub-scopes.
            memory_optimizer_target_node_name_scope='gradients/'))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, original_metagraph)
    self.assertGreater(
        len(rewritten_graph_def.node),
        len(original_metagraph.graph_def.node))
    self.assertEqual(
        0,
        len([node for node in original_metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    self.assertEqual(
        20,  # Two per layer
        len([node for node in rewritten_graph_def.node
             if 'Recomputed/' in node.name]))

  def testRewritingNameScopedGradientNamesScope(self):
    """Tests that rewriting does not occur with a non-matching name scope."""
    (original_metagraph, _, _,
     _) = self._GetMetaGraph(optimizer_scope_name='foo/bar')
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig
            .RECOMPUTATION_HEURISTICS,
            # This should not match anything.
            memory_optimizer_target_node_name_scope='r/gradients/'))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, original_metagraph)
    self.assertEqual(
        len(rewritten_graph_def.node), len(original_metagraph.graph_def.node))
    self.assertEqual(0,
                     len([
                         node for node in original_metagraph.graph_def.node
                         if 'Recomputed/' in node.name
                     ]))
    self.assertEqual(0,
                     len([
                         node for node in rewritten_graph_def.node
                         if 'Recomputed/' in node.name
                     ]))

  def _GetMemoryOptimizerSessionConfig(self):
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def _RunMetaGraphWithConfig(
      self, config, metagraph, init_op_name, train_op_name, loss_op_name):
    graph = ops.Graph()
    with graph.as_default():
      train.import_meta_graph(metagraph)
      init_op = graph.get_operation_by_name(init_op_name)
      train_op = graph.get_operation_by_name(train_op_name)
      loss_op = graph.get_tensor_by_name(loss_op_name)
      with session.Session(config=config, graph=graph) as sess:
        self.evaluate(init_op)
        self.evaluate(train_op)
        self.evaluate(train_op)
        return self.evaluate(loss_op)

  def testRecomputationRewritingNoErrors(self):
    """Tests that graph output is not significantly different with rewriting."""
    (original_metagraph, init_op_name, train_op_name, loss_op_name
    ) = self._GetMetaGraph()
    original_loss = self._RunMetaGraphWithConfig(
        config=config_pb2.ConfigProto(),
        metagraph=original_metagraph,
        init_op_name=init_op_name,
        train_op_name=train_op_name,
        loss_op_name=loss_op_name)
    memory_optimized_loss = self._RunMetaGraphWithConfig(
        config=self._GetMemoryOptimizerSessionConfig(),
        metagraph=original_metagraph,
        init_op_name=init_op_name,
        train_op_name=train_op_name,
        loss_op_name=loss_op_name)
    self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-2)

  def _annotated_graph(self):
    graph = ops.Graph()
    with graph.as_default():
      random_seed.set_random_seed(2)
      current_activation = variable_scope.get_variable(
          name='start', shape=[1, 2, 2, 5])
      conv_filter = variable_scope.get_variable(
          name='filter', shape=[5, 5, 5, 5])
      for layer_number in range(3):
        with variable_scope.variable_scope('layer_{}'.format(layer_number)):
          after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1],
                                 'SAME')
          current_activation = 2. * after_conv
          current_activation.op._set_attr(
              '_recompute_hint',
              # The value of the attribute does not matter; just that the key
              # exists in the op's attributes.
              attr_value_pb2.AttrValue(i=1))
          current_activation += 5.
          current_activation.op._set_attr(
              '_recompute_hint', attr_value_pb2.AttrValue(i=0))
          current_activation = nn.relu(current_activation)
          current_activation.op._set_attr(
              '_recompute_hint', attr_value_pb2.AttrValue(i=1))
      loss = math_ops.reduce_mean(current_activation)
      optimizer = train.AdamOptimizer(0.001)
      train_op = optimizer.minimize(loss)
      init_op = variables.global_variables_initializer()
    return graph, init_op, train_op

  def testHintNoMetaGraph(self):
    # Closer to expected usage, but does not check that a re-write actually
    # happens; see testHintDoesRewrite.
    graph, init_op, train_op = self._annotated_graph()
    with graph.as_default():
      manual_memory_config = rewriter_config_pb2.RewriterConfig(
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      graph_options = config_pb2.GraphOptions(
          rewrite_options=manual_memory_config)
      session_config = config_pb2.ConfigProto(graph_options=graph_options)
      with session.Session(config=session_config) as sess:
        self.evaluate(init_op)
        self.evaluate(train_op)

  @test_util.run_v1_only('b/120545219')
  def testHintDoesRewrite(self):
    graph = self._annotated_graph()[0]
    with graph.as_default():
      metagraph = train.export_meta_graph()
    self.assertEqual(
        0,
        len([node for node in metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            min_graph_nodes=-1,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, metagraph)
    self.assertEqual(
        9,
        len([
            node for node in rewritten_graph_def.node
            if 'Recomputed/' in node.name
        ]))


if __name__ == '__main__':
  test.main()