xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/cost_analyzer_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the cost analyzer."""
16
17import re
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import meta_graph
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import test_util
24from tensorflow.python.grappler import cost_analyzer
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
28from tensorflow.python.ops import nn_ops
29from tensorflow.python.ops import random_ops
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import test
32from tensorflow.python.training import adam
33
34
class CostAnalysisTest(test.TestCase):
  """Tests for the cost and memory reports produced by `cost_analyzer`."""

  def _build_add_n_meta_graph(self):
    """Builds the small a/b/c/d AddN graph and exports the default graph.

    Two integer constants `a` and `b` feed `c = add_n([a, b])` and
    `d = add_n([b, c])`. `d` is appended to the TRAIN_OP collection before
    the default graph is exported.

    Returns:
      The result of `meta_graph.create_meta_graph_def` for the default graph.
    """
    a = constant_op.constant(10, name="a")
    b = constant_op.constant(20, name="b")
    c = math_ops.add_n([a, b], name="c")
    d = math_ops.add_n([b, c], name="d")
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(d)
    return meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

  @test_util.run_deprecated_v1
  def testBasicCost(self):
    """Make sure arguments can be passed correctly."""
    mg = self._build_add_n_meta_graph()

    report = cost_analyzer.GenerateCostReport(mg, per_node_report=True)

    # Print the report before asserting so it is available when a check fails.
    print(report)

    # Check the report headers. assertIn names the missing header on failure.
    expected_headers = (
        b"Total time measured in ns (serialized):",
        b"Total time measured in ns (actual):",
        b"Total time analytical in ns (upper bound):",
        b"Total time analytical in ns (lower bound):",
        b"Overall efficiency (analytical upper/actual):",
        b"Overall efficiency (analytical lower/actual):",
        b"Below is the per-node report summary:",
    )
    for header in expected_headers:
      self.assertIn(header, report)

  @test_util.run_deprecated_v1
  def testVerbose(self):
    """Make sure the full report is generated with verbose=True."""
    mg = self._build_add_n_meta_graph()

    report = cost_analyzer.GenerateCostReport(
        mg, per_node_report=True, verbose=True)

    # Print the report before asserting so it is available when a check fails.
    print(report)

    # Check the report headers
    self.assertIn(b"Below is the full per-node report:", report)

  @test_util.run_deprecated_v1
  def testSmallNetworkCost(self):
    """Checks the cost report of a small conv + dense network using Adam."""
    image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1])
    label = array_ops.placeholder(dtypes.float32, shape=[1, 10])
    w = variables.Variable(
        random_ops.truncated_normal([5, 5, 1, 32], stddev=0.1))
    b = variables.Variable(random_ops.truncated_normal([32], stddev=0.1))
    conv = nn_ops.conv2d(image, w, strides=[1, 1, 1, 1], padding="SAME")
    h_conv = nn_ops.relu(conv + b)
    h_conv_flat = array_ops.reshape(h_conv, [1, -1])

    w_fc = variables.Variable(
        random_ops.truncated_normal([25088, 10], stddev=0.1))
    b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1))
    y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc)

    cross_entropy = math_ops.reduce_mean(
        -math_ops.reduce_sum(label * math_ops.log(y_conv), axis=[1]))
    # The training op is only needed for its effect on the graph.
    _ = adam.AdamOptimizer(1e-4).minimize(cross_entropy)

    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    report = cost_analyzer.GenerateCostReport(mg)

    # Print the report to make it easier to debug
    print(report)

    for op_name in (b"MatMul", b"ApplyAdam", b"Conv2DBackpropFilter",
                    b"Softmax"):
      self.assertIn(op_name, report)

    # When mkl is enabled, Conv2D and MatMul op followed by
    # 1-dimension Add in this graph will be fused, but not
    # in the mkl disabled case.
    expected_matmul_count = 2
    op_types = [b"MatMul", b"Conv2DBackpropFilter"]

    if not test_util.IsMklEnabled():
      # NOTE: this substring is also contained in "Conv2DBackpropFilter"; the
      # per-row regex below is what actually pins the standalone Conv2D entry.
      self.assertIn(b"Conv2D", report)
      expected_matmul_count = 3
      op_types.append(b"Conv2D")

    for op_type in op_types:
      # One report row: op name, two integer columns, two percentages, then
      # two more integer columns. The trailing comma after the op name keeps
      # "Conv2D" from matching the "Conv2DBackpropFilter" row.
      matcher = re.compile(
          br"\s+" + re.escape(op_type) +
          br",\s*(\d+),\s*(\d+),\s*([\d\.eE+-]+)%,\s*" +
          br"([\d\.eE+-]+)%,\s*(-?\d+),\s*(\d+),", re.MULTILINE)
      m = matcher.search(report)
      # Fail with a readable message instead of an AttributeError on None.
      self.assertIsNotNone(
          m, "No report row found for op type {!r}".format(op_type))

      op_count = int(m.group(1))
      # upper = int(m.group(5))
      lower = int(m.group(6))
      if op_type == b"MatMul":
        self.assertEqual(expected_matmul_count, op_count)
      else:
        self.assertEqual(1, op_count)
      self.assertGreaterEqual(lower, 0)
      # NOTE(review): the upper-bound checks below are disabled in the
      # original test; the reason is not visible here - confirm before
      # re-enabling.
      # self.assertTrue(0 < upper)
      # self.assertTrue(lower <= upper)

  @test_util.run_deprecated_v1
  def testBasicMemory(self):
    """Make sure the memory report lists peak usage and per-tensor sizes."""
    with test_util.device(use_gpu=False):
      mg = self._build_add_n_meta_graph()

    report = cost_analyzer.GenerateMemoryReport(mg)

    # Print the report to make it easier to debug
    print(report)

    # Check the report
    self.assertIn(
        "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
        "16 bytes",
        report)
    for tensor_name in ("a", "b", "c", "d"):
      self.assertIn("  {}:0 uses 4 bytes".format(tensor_name), report)
167
168
# Run the tests in this module through the TensorFlow test runner when the
# file is executed directly.
if __name__ == "__main__":
  test.main()
171