xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/cluster_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the swig wrapper of clusters."""
16
17from tensorflow.core.protobuf import device_properties_pb2
18from tensorflow.python.framework import meta_graph
19from tensorflow.python.framework import ops
20from tensorflow.python.grappler import cluster
21from tensorflow.python.grappler import item
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import random_ops
25from tensorflow.python.platform import test
26
27
class ClusterTest(test.TestCase):
  """Tests for the Python wrapper around grappler clusters.

  Each test builds a small graph, registers its output in the TRAIN_OP
  collection, wraps the resulting MetaGraphDef in a grappler `Item`, and then
  queries a `Cluster` for costs, memory usage, available ops, or supported
  devices.
  """

  def testBasic(self):
    """MeasureCosts reports per-op costs and step stats when both are enabled."""
    with ops.Graph().as_default() as g:
      a = random_ops.random_uniform(shape=())
      b = random_ops.random_uniform(shape=())
      c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      grappler_cluster = cluster.Cluster(
          disable_detailed_stats=False, disable_timeline=False)
      op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
          grappler_item)
      # assertGreater reports the measured value on failure, unlike
      # assertTrue(run_time > 0).
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 4)
      self.assertTrue(step_stats.dev_stats)

  def testNoDetailedStats(self):
    """MeasureCosts returns no per-op costs or device stats when disabled."""
    with ops.Graph().as_default() as g:
      a = random_ops.random_uniform(shape=())
      b = random_ops.random_uniform(shape=())
      c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      grappler_cluster = cluster.Cluster(disable_detailed_stats=True)

      op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
          grappler_item)
      # The overall run time is still measured even without detailed stats.
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 0)
      self.assertEqual(len(step_stats.dev_stats), 0)

  def testMemoryEstimates(self):
    """DeterminePeakMemoryUsage reports peak usage and live tensors for CPU:0."""
    with ops.Graph().as_default() as g:
      with ops.device('/job:localhost/replica:0/task:0/device:CPU:0'):
        a = random_ops.random_uniform(shape=())
        b = random_ops.random_uniform(shape=())
        c = a + b
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        train_op.append(c)
        mg = meta_graph.create_meta_graph_def(graph=g)
        grappler_item = item.Item(mg)
        grappler_cluster = cluster.Cluster(
            disable_detailed_stats=True, disable_timeline=True)
        peak_mem = grappler_cluster.DeterminePeakMemoryUsage(grappler_item)
        self.assertLessEqual(1, len(peak_mem))
        # Each entry is indexed as (peak usage, live tensors) below.
        snapshot = peak_mem['/job:localhost/replica:0/task:0/device:CPU:0']
        peak_usage = snapshot[0]
        self.assertEqual(12, peak_usage)
        live_tensors = snapshot[1]
        self.assertEqual(5, len(live_tensors))

  def testVirtualCluster(self):
    """A cluster built from an explicit GPU description yields fixed estimates."""
    with ops.Graph().as_default() as g:
      with ops.device('/device:GPU:0'):
        a = random_ops.random_uniform(shape=[1024, 1024])
        b = random_ops.random_uniform(shape=[1024, 1024])
        c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      device_properties = device_properties_pb2.DeviceProperties(
          type='GPU',
          frequency=1000,
          num_cores=60,
          environment={'architecture': '7'})
      named_device = device_properties_pb2.NamedDevice(
          properties=device_properties, name='/device:GPU:0')
      grappler_cluster = cluster.Cluster(
          disable_detailed_stats=False,
          disable_timeline=False,
          devices=[named_device])
      op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
      # NOTE(review): exact float comparison; presumably the estimate for a
      # described (non-physical) device is deterministic -- confirm.
      self.assertEqual(run_time, 0.000209)
      self.assertEqual(len(op_perfs), 5)

      estimated_perf = grappler_cluster.EstimatePerformance(named_device)
      self.assertEqual(7680.0, estimated_perf)

  def testContext(self):
    """cluster.Provision works as a context manager around MeasureCosts."""
    with ops.Graph().as_default() as g:
      a = random_ops.random_uniform(shape=())
      b = random_ops.random_uniform(shape=())
      c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)

    # The item is measured after the graph context above has been exited.
    with cluster.Provision(
        disable_detailed_stats=False, disable_timeline=False) as gcluster:
      op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 4)
      self.assertTrue(step_stats.dev_stats)

  def testAvailableOps(self):
    """ListAvailableOps returns a sorted list containing common op names."""
    with cluster.Provision() as gcluster:
      op_names = gcluster.ListAvailableOps()
      # assertIn names the missing op on failure, unlike assertTrue(x in y).
      self.assertIn('Add', op_names)
      self.assertIn('MatMul', op_names)
      self.assertEqual(op_names, sorted(op_names))

  def testSupportDevices(self):
    """GetSupportedDevices maps ops to devices for described and default clusters."""
    with ops.Graph().as_default() as g:
      a = random_ops.random_uniform(shape=(2, 3))
      b = random_ops.random_uniform(shape=(2, 3))
      c = a + b
      dims = math_ops.range(0, array_ops.rank(c), 1)
      d = math_ops.reduce_sum(a, axis=dims)
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(d)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)

      # Cluster described by explicit device properties (one CPU, one GPU).
      device_properties = device_properties_pb2.DeviceProperties(
          type='GPU', frequency=1000, num_cores=60)
      named_gpu = device_properties_pb2.NamedDevice(
          properties=device_properties, name='/GPU:0')
      device_properties = device_properties_pb2.DeviceProperties(
          type='CPU', frequency=3000, num_cores=6)
      named_cpu = device_properties_pb2.NamedDevice(
          properties=device_properties, name='/CPU:0')
      virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu])
      supported_dev = virtual_cluster.GetSupportedDevices(grappler_item)
      self.assertEqual(supported_dev['add'], ['/CPU:0', '/GPU:0'])
      self.assertEqual(supported_dev['Sum'], ['/CPU:0', '/GPU:0'])
      self.assertEqual(supported_dev['range'], ['/CPU:0', '/GPU:0'])

      # Cluster created without a device list; expectations depend on whether
      # a GPU is available to the test.
      real_cluster = cluster.Cluster()
      supported_dev = real_cluster.GetSupportedDevices(grappler_item)
      if test.is_gpu_available():
        self.assertEqual(supported_dev['add'], [
            '/job:localhost/replica:0/task:0/device:CPU:0',
            '/job:localhost/replica:0/task:0/device:GPU:0'
        ])
        self.assertEqual(supported_dev['Sum'], [
            '/job:localhost/replica:0/task:0/device:CPU:0',
            '/job:localhost/replica:0/task:0/device:GPU:0'
        ])
        # The axis tensor must reside on the host
        self.assertEqual(supported_dev['range'],
                         ['/job:localhost/replica:0/task:0/device:CPU:0'])
      else:
        self.assertEqual(supported_dev['add'],
                         ['/job:localhost/replica:0/task:0/device:CPU:0'])
179
180
# Run the tests through the TensorFlow test runner when this file is executed
# as a script.
if __name__ == '__main__':
  test.main()
183