# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper tf_optimizer."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import item as gitem
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class PyWrapOptimizeGraphTest(test.TestCase):
  """Runs `tf_optimizer.OptimizeGraph` on small exported MetaGraphDefs."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Make sure arguments can be passed correctly.

    Only the 'constfold' optimizer is requested. Every input is a constant,
    so the optimized graph is expected to contain just the fetch node 'd'.
    """
    a = constant_op.constant(10, name='a')
    b = constant_op.constant(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    # Being a train_op causes 'd' to be added as a fetch node.
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.optimizers.append('constfold')
    # A negative value means optimization is never skipped because the graph
    # is small (see RewriterConfig.min_graph_nodes).
    rewriter_config.min_graph_nodes = -1

    graph = tf_optimizer.OptimizeGraph(config, mg)

    self.assertEqual(len(graph.node), 1)
    self.assertItemsEqual([node.name for node in graph.node], ['d'])

  @test_util.run_v1_only('b/120545219')
  def testKeepNodes(self):
    """Checks that nodes which must survive optimization are kept.

    The optimized graph must contain exactly:
      * the fetch node `d` (registered in the 'train_op' collection);
      * the variable `a1` and its 'Variable/initial_value' and
        'Variable/Assign' nodes (the variable is in the 'variables'
        collection);
      * `a2`, which is explicitly added to a user collection;
      * `a3`, which carries the `_grappler_do_not_remove` attribute.
    The matmul inputs `b` and `c` are not in the expected set, so they are
    expected to be removed.
    """
    g = ops.Graph()
    with g.as_default():
      a1 = variables.VariableV1(
          1.0)  # Must be preserved since it's in the collection 'variables'.
      a2 = constant_op.constant(0, shape=[50, 50], name='keep')
      ops.add_to_collection('a2', a2)  # Explicitly add to collection.
      with g._attr_scope(
          {'_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)}):
        a3 = constant_op.constant(0, name='keep2')
      b = constant_op.constant(1, shape=[100, 10])
      c = constant_op.constant(0, shape=[10, 30])
      d = math_ops.matmul(b, c)
      ops.add_to_collection('train_op', d)  # d is the fetch node.

    # Optimize the graph.
    mg = meta_graph.create_meta_graph_def(graph=g)
    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    # Never skip optimization just because the graph is small.
    rewriter_config.min_graph_nodes = -1
    optimized_graph = tf_optimizer.OptimizeGraph(config, mg)

    # Check that the nodes referenced in various collections have been
    # preserved, and that nothing else is left in the graph.
    optimized_graph_nodes = [node.name for node in optimized_graph.node]
    expected_nodes = [
        d.op.name, a1.op.name, a2.op.name, a3.op.name, 'Variable/initial_value',
        'Variable/Assign'
    ]
    # One multiset comparison replaces the earlier length check plus
    # membership check, and gives a clearer diff when it fails.
    self.assertItemsEqual(optimized_graph_nodes, expected_nodes)

  @test_util.run_v1_only('b/120545219')
  def testLoops(self):
    """Checks op properties of a while-loop output after optimization.

    `buf` is the output of a while loop that grows a vector. `f` is
    `-ones_like(buf)`. After `OptimizeGraph`, the properties inferred for
    both ops in the optimized graph must be equal.
    """
    g = ops.Graph()
    with g.as_default():

      def _Cond(_, counter):
        # `end` is bound below, before the loop is built, so this closure
        # can see it when it is called.
        return counter < end

      def _Body(buf, counter):
        buf = array_ops.concat([buf, [counter]], 0)
        counter += 1
        return [buf, counter]

      start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
      end = array_ops.placeholder(shape=[], dtype=dtypes.int32)
      init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32)
      loop_vars = [init_buf, start]
      # The buffer grows each iteration, so its shape is declared [None].
      shape_inv = [
          tensor_shape.TensorShape([None]),
          tensor_shape.TensorShape([])
      ]
      buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars, shape_inv)

      # optimize=False makes ones_like derive its shape from `buf` at run
      # time rather than encoding a statically known shape as a constant.
      f = -array_ops.ones_like(buf, optimize=False)  # pylint: disable=invalid-unary-operand-type
      buf_shape = array_ops.shape(buf)
      f_shape = array_ops.shape(f)
      ops.add_to_collection('train_op', buf_shape)
      ops.add_to_collection('train_op', f_shape)

    # Optimize the graph.
    mg = meta_graph.create_meta_graph_def(graph=g)
    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    # Never skip optimization just because the graph is small.
    rewriter_config.min_graph_nodes = -1
    optimized_graph = tf_optimizer.OptimizeGraph(config, mg)
    # Swap the optimized graph into the MetaGraphDef so the Item below is
    # built from the optimizer's output, not the original graph.
    mg.graph_def.CopyFrom(optimized_graph)

    # Check that the while-loop output and its negated ones_like get
    # identical inferred properties in the optimized graph.
    item = gitem.Item(mg)
    props = item.GetOpProperties()
    buf_prop = props[buf.op.name]
    f_prop = props[f.op.name]
    self.assertEqual(buf_prop, f_prop)


if __name__ == '__main__':
  test.main()