# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the cost analyzer."""

import re

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam


def _create_add_n_meta_graph():
  """Builds a tiny add_n graph in the default graph and exports it.

  The graph holds two scalar constants `a` (10) and `b` (20) and computes
  `c = add_n([a, b])` and `d = add_n([b, c])`. Node `d` is appended to the
  TRAIN_OP collection; presumably this is what makes grappler treat it as
  the fetch node when analyzing the graph.

  Must be called inside a graph-mode test (e.g. one decorated with
  `test_util.run_deprecated_v1`), since it uses `ops.get_default_graph()`.

  Returns:
    The MetaGraphDef produced by `meta_graph.create_meta_graph_def` for the
    current default graph.
  """
  a = constant_op.constant(10, name="a")
  b = constant_op.constant(20, name="b")
  c = math_ops.add_n([a, b], name="c")
  d = math_ops.add_n([b, c], name="d")
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(d)
  return meta_graph.create_meta_graph_def(graph=ops.get_default_graph())


class CostAnalysisTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testBasicCost(self):
    """Checks that the per-node cost report contains every summary header."""
    mg = _create_add_n_meta_graph()

    report = cost_analyzer.GenerateCostReport(mg, per_node_report=True)

    # Check the report headers. assertIn reports which header is missing.
    expected_headers = (
        b"Total time measured in ns (serialized):",
        b"Total time measured in ns (actual):",
        b"Total time analytical in ns (upper bound):",
        b"Total time analytical in ns (lower bound):",
        b"Overall efficiency (analytical upper/actual):",
        b"Overall efficiency (analytical lower/actual):",
        b"Below is the per-node report summary:",
    )
    for header in expected_headers:
      self.assertIn(header, report)

    # Also print the report to make it easier to debug
    print(report)

  @test_util.run_deprecated_v1
  def testVerbose(self):
    """Make sure the full report is generated with verbose=True."""
    mg = _create_add_n_meta_graph()

    report = cost_analyzer.GenerateCostReport(
        mg, per_node_report=True, verbose=True)

    # Check the report headers
    self.assertIn(b"Below is the full per-node report:", report)

    # Also print the report to make it easier to debug
    print(report)

  @test_util.run_deprecated_v1
  def testSmallNetworkCost(self):
    """Checks per-op-type counts and bounds for a small conv net + Adam."""
    image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1])
    label = array_ops.placeholder(dtypes.float32, shape=[1, 10])
    w = variables.Variable(
        random_ops.truncated_normal([5, 5, 1, 32], stddev=0.1))
    b = variables.Variable(random_ops.truncated_normal([32], stddev=0.1))
    conv = nn_ops.conv2d(image, w, strides=[1, 1, 1, 1], padding="SAME")
    h_conv = nn_ops.relu(conv + b)
    h_conv_flat = array_ops.reshape(h_conv, [1, -1])

    # 25088 == 28 * 28 * 32: the SAME-padded, stride-1 conv keeps the 28x28
    # spatial size and produces 32 channels, flattened above.
    w_fc = variables.Variable(
        random_ops.truncated_normal([25088, 10], stddev=0.1))
    b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1))
    y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc)

    cross_entropy = math_ops.reduce_mean(
        -math_ops.reduce_sum(label * math_ops.log(y_conv), axis=[1]))
    _ = adam.AdamOptimizer(1e-4).minimize(cross_entropy)

    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    report = cost_analyzer.GenerateCostReport(mg)

    # Print the report to make it easier to debug
    print(report)

    self.assertIn(b"MatMul", report)
    self.assertIn(b"ApplyAdam", report)
    self.assertIn(b"Conv2DBackpropFilter", report)
    self.assertIn(b"Softmax", report)

    # When mkl is enabled, Conv2D and MatMul op followed by
    # 1-dimension Add in this graph will be fused, but not
    # in the mkl disabled case.
    expected_matmul_count = 2
    op_types = [b"MatMul", b"Conv2DBackpropFilter"]

    if not test_util.IsMklEnabled():
      self.assertIn(b"Conv2D", report)
      expected_matmul_count = 3
      op_types.append(b"Conv2D")

    for op_type in op_types:
      # Match the summary row for this op type. Group 1 is the op count,
      # group 5 the analytical upper bound and group 6 the lower bound.
      # The trailing "," after the op type keeps b"Conv2D" from matching
      # the b"Conv2DBackpropFilter" row.
      matcher = re.compile(
          br"\s+" + re.escape(op_type) +
          br",\s*(\d+),\s*(\d+),\s*([\d\.eE+-]+)%,\s*" +
          br"([\d\.eE+-]+)%,\s*(-?\d+),\s*(\d+),", re.MULTILINE)
      m = matcher.search(report)
      # Fail with a clear message rather than an AttributeError on None.
      self.assertIsNotNone(
          m, "No summary row found in the report for op type %r" % op_type)

      op_count = int(m.group(1))
      lower = int(m.group(6))
      if op_type == b"MatMul":
        self.assertEqual(expected_matmul_count, op_count)
      else:
        self.assertEqual(1, op_count)
      self.assertGreaterEqual(lower, 0)
      # NOTE: checks on the upper bound (group 5), i.e. `0 < upper` and
      # `lower <= upper`, are currently disabled.

  @test_util.run_deprecated_v1
  def testBasicMemory(self):
    """Checks the memory report for the small add_n graph pinned to CPU."""
    with test_util.device(use_gpu=False):
      mg = _create_add_n_meta_graph()

    report = cost_analyzer.GenerateMemoryReport(mg)

    # Print the report to make it easier to debug
    print(report)

    # Check the report
    self.assertIn(
        "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
        "16 bytes",
        report)
    for tensor_name in ("a", "b", "c", "d"):
      self.assertIn(" %s:0 uses 4 bytes" % tensor_name, report)


if __name__ == "__main__":
  test.main()