# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper of clusters."""

from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import item
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test

# Fully qualified names of the local devices, as used by the assertions on the
# real (non-virtual) cluster below.
_LOCAL_CPU0 = '/job:localhost/replica:0/task:0/device:CPU:0'
_LOCAL_GPU0 = '/job:localhost/replica:0/task:0/device:GPU:0'


def _random_scalar_sum():
  """Adds two random scalars to the default graph and returns their sum."""
  a = random_ops.random_uniform(shape=())
  b = random_ops.random_uniform(shape=())
  return a + b


def _make_grappler_item(build_fn):
  """Builds a fresh graph with `build_fn` and wraps it in a grappler `Item`.

  Args:
    build_fn: Zero-argument callable that adds ops to the current default
      graph and returns the tensor to register in the TRAIN_OP collection.

  Returns:
    An `item.Item` created from the MetaGraphDef of the new graph.
  """
  with ops.Graph().as_default() as g:
    fetch = build_fn()
    # Registering the result as a train op gives the item something to fetch
    # (presumably how grappler's Item chooses its fetch nodes -- confirm).
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(fetch)
    mg = meta_graph.create_meta_graph_def(graph=g)
    return item.Item(mg)


class ClusterTest(test.TestCase):

  def testBasic(self):
    grappler_item = _make_grappler_item(_random_scalar_sum)
    grappler_cluster = cluster.Cluster(
        disable_detailed_stats=False, disable_timeline=False)
    op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
        grappler_item)
    self.assertGreater(run_time, 0)
    self.assertEqual(len(op_perfs), 4)
    self.assertTrue(step_stats.dev_stats)

  def testNoDetailedStats(self):
    grappler_item = _make_grappler_item(_random_scalar_sum)
    grappler_cluster = cluster.Cluster(disable_detailed_stats=True)

    op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
        grappler_item)
    self.assertGreater(run_time, 0)
    # With detailed stats disabled, no per-op data should be collected.
    self.assertEqual(len(op_perfs), 0)
    self.assertEqual(len(step_stats.dev_stats), 0)

  def testMemoryEstimates(self):

    def build():
      with ops.device(_LOCAL_CPU0):
        return _random_scalar_sum()

    grappler_item = _make_grappler_item(build)
    grappler_cluster = cluster.Cluster(
        disable_detailed_stats=True, disable_timeline=True)
    peak_mem = grappler_cluster.DeterminePeakMemoryUsage(grappler_item)
    self.assertGreaterEqual(len(peak_mem), 1)
    snapshot = peak_mem[_LOCAL_CPU0]
    peak_usage = snapshot[0]
    self.assertEqual(12, peak_usage)
    live_tensors = snapshot[1]
    self.assertEqual(5, len(live_tensors))

  def testVirtualCluster(self):

    def build():
      with ops.device('/device:GPU:0'):
        a = random_ops.random_uniform(shape=[1024, 1024])
        b = random_ops.random_uniform(shape=[1024, 1024])
        return a + b

    grappler_item = _make_grappler_item(build)
    device_properties = device_properties_pb2.DeviceProperties(
        type='GPU',
        frequency=1000,
        num_cores=60,
        environment={'architecture': '7'})
    named_device = device_properties_pb2.NamedDevice(
        properties=device_properties, name='/device:GPU:0')
    grappler_cluster = cluster.Cluster(
        disable_detailed_stats=False,
        disable_timeline=False,
        devices=[named_device])
    op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
    self.assertEqual(run_time, 0.000209)
    self.assertEqual(len(op_perfs), 5)

    estimated_perf = grappler_cluster.EstimatePerformance(named_device)
    self.assertEqual(7680.0, estimated_perf)

  def testContext(self):
    grappler_item = _make_grappler_item(_random_scalar_sum)

    with cluster.Provision(
        disable_detailed_stats=False, disable_timeline=False) as gcluster:
      op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 4)
      self.assertTrue(step_stats.dev_stats)

  def testAvailableOps(self):
    with cluster.Provision() as gcluster:
      op_names = gcluster.ListAvailableOps()
      self.assertIn('Add', op_names)
      self.assertIn('MatMul', op_names)
      self.assertEqual(op_names, sorted(op_names))

  def testSupportDevices(self):

    def build():
      a = random_ops.random_uniform(shape=(2, 3))
      b = random_ops.random_uniform(shape=(2, 3))
      c = a + b
      dims = math_ops.range(0, array_ops.rank(c), 1)
      # `a` is reduced over axes derived from `c`, so the `add`, `range` and
      # `Sum` nodes all feed the fetched tensor.
      return math_ops.reduce_sum(a, axis=dims)

    grappler_item = _make_grappler_item(build)

    device_properties = device_properties_pb2.DeviceProperties(
        type='GPU', frequency=1000, num_cores=60)
    named_gpu = device_properties_pb2.NamedDevice(
        properties=device_properties, name='/GPU:0')
    device_properties = device_properties_pb2.DeviceProperties(
        type='CPU', frequency=3000, num_cores=6)
    named_cpu = device_properties_pb2.NamedDevice(
        properties=device_properties, name='/CPU:0')
    virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu])
    supported_dev = virtual_cluster.GetSupportedDevices(grappler_item)
    self.assertEqual(supported_dev['add'], ['/CPU:0', '/GPU:0'])
    self.assertEqual(supported_dev['Sum'], ['/CPU:0', '/GPU:0'])
    self.assertEqual(supported_dev['range'], ['/CPU:0', '/GPU:0'])

    real_cluster = cluster.Cluster()
    supported_dev = real_cluster.GetSupportedDevices(grappler_item)
    if test.is_gpu_available():
      self.assertEqual(supported_dev['add'], [_LOCAL_CPU0, _LOCAL_GPU0])
      self.assertEqual(supported_dev['Sum'], [_LOCAL_CPU0, _LOCAL_GPU0])
      # The axis tensor must reside on the host
      self.assertEqual(supported_dev['range'], [_LOCAL_CPU0])
    else:
      self.assertEqual(supported_dev['add'], [_LOCAL_CPU0])


if __name__ == '__main__':
  test.main()