xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/tf_optimizer_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the swig wrapper tf_optimizer."""
16
17from tensorflow.core.framework import attr_value_pb2
18from tensorflow.core.protobuf import config_pb2
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import meta_graph
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.framework import test_util
25from tensorflow.python.grappler import item as gitem
26from tensorflow.python.grappler import tf_optimizer
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import test
32
33
class PyWrapOptimizeGraphTest(test.TestCase):
  """Runs tf_optimizer.OptimizeGraph over small hand-built graphs."""

  def _RunGrappler(self, meta_graph_def, extra_optimizers=()):
    """Optimizes `meta_graph_def` with Grappler and returns the result.

    Args:
      meta_graph_def: The MetaGraphDef to optimize. Its collections (e.g.
        'train_op', 'variables') drive which nodes are fetched or kept.
      extra_optimizers: Optimizer names appended, in order, to
        `config.graph_options.rewrite_options.optimizers`. When empty, the
        list is left untouched.

    Returns:
      The optimized GraphDef produced by `tf_optimizer.OptimizeGraph`.
    """
    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    for optimizer_name in extra_optimizers:
      rewrite_options.optimizers.append(optimizer_name)
    # A negative min_graph_nodes stops Grappler from skipping graphs as tiny
    # as the ones built here (see RewriterConfig.min_graph_nodes).
    rewrite_options.min_graph_nodes = -1
    return tf_optimizer.OptimizeGraph(config, meta_graph_def)

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Checks that the config and MetaGraphDef reach the optimizer intact."""
    const_a = constant_op.constant(10, name='a')
    const_b = constant_op.constant(20, name='b')
    sum_c = math_ops.add_n([const_a, const_b], name='c')
    sum_d = math_ops.add_n([const_b, sum_c], name='d')
    # Registering 'd' as a train op turns it into a fetch node for Grappler.
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)

    optimized = self._RunGrappler(
        meta_graph.create_meta_graph_def(graph=ops.get_default_graph()),
        extra_optimizers=('constfold',))

    # With constant folding on, only the fetched node 'd' should remain.
    node_names = [node.name for node in optimized.node]
    self.assertEqual(len(optimized.node), 1)
    self.assertItemsEqual(node_names, ['d'])

  @test_util.run_v1_only('b/120545219')
  def testKeepNodes(self):
    """Nodes tied to collections or marked do-not-remove are preserved."""
    graph = ops.Graph()
    keep_attrs = {
        '_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)
    }
    with graph.as_default():
      # Lives in the 'variables' collection, so it must survive.
      var = variables.VariableV1(1.0)
      collected_const = constant_op.constant(0, shape=[50, 50], name='keep')
      # Placed in a collection explicitly.
      ops.add_to_collection('a2', collected_const)
      # Carries the _grappler_do_not_remove attribute via the attr scope.
      with graph._attr_scope(keep_attrs):
        tagged_const = constant_op.constant(0, name='keep2')
      lhs = constant_op.constant(1, shape=[100, 10])
      rhs = constant_op.constant(0, shape=[10, 30])
      product = math_ops.matmul(lhs, rhs)
      # The product is the fetch node.
      ops.add_to_collection('train_op', product)

    optimized = self._RunGrappler(meta_graph.create_meta_graph_def(graph=graph))

    # Every node referenced by a collection (plus the variable's own
    # initializer/assign nodes) must still be present, and nothing else.
    surviving_names = [node.name for node in optimized.node]
    expected_names = [
        product.op.name,
        var.op.name,
        collected_const.op.name,
        tagged_const.op.name,
        'Variable/initial_value',
        'Variable/Assign',
    ]
    self.assertEqual(len(surviving_names), len(expected_names))
    self.assertAllInSet(surviving_names, expected_names)

  @test_util.run_v1_only('b/120545219')
  def testLoops(self):
    """Op properties of a while-loop output match those of -ones_like(it)."""
    graph = ops.Graph()
    with graph.as_default():
      start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
      end = array_ops.placeholder(shape=[], dtype=dtypes.int32)

      def _Cond(_, counter):
        return counter < end

      def _Body(acc, counter):
        # Append the current counter value to the buffer, then advance it.
        return [array_ops.concat([acc, [counter]], 0), counter + 1]

      init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32)
      # The buffer grows every iteration, so its length is left unknown.
      shape_invariants = [
          tensor_shape.TensorShape([None]),
          tensor_shape.TensorShape([]),
      ]
      buf, _ = control_flow_ops.while_loop(_Cond, _Body, [init_buf, start],
                                           shape_invariants)

      neg_ones = -array_ops.ones_like(buf, optimize=False)  # pylint: disable=invalid-unary-operand-type
      # Registering both shapes as train ops makes them fetch nodes.
      for fetched in (array_ops.shape(buf), array_ops.shape(neg_ones)):
        ops.add_to_collection('train_op', fetched)

    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    meta_graph_def.graph_def.CopyFrom(self._RunGrappler(meta_graph_def))

    # Inspect the optimized graph: both ops should report equal properties.
    op_props = gitem.Item(meta_graph_def).GetOpProperties()
    self.assertEqual(op_props[buf.op.name], op_props[neg_ones.op.name])
133
134
135if __name__ == '__main__':
136  test.main()
137