# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for device utilities (`current`, `canonicalize` and `resolve`)."""

from absl.testing import parameterized

from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib


class DeviceUtilTest(test.TestCase, parameterized.TestCase):
  """Checks the device strings produced by the `device_util` helpers."""

  def setUp(self):
    super().setUp()
    # Reset the global eager context before every test; one of the tests
    # below reconfigures it via `enable_collective_ops`.
    context._reset_context()  # pylint: disable=protected-access

  @combinations.generate(combinations.combine(mode="graph"))
  def testCurrentDeviceWithGlobalGraph(self):
    """`current()` reports the innermost device scope in graph mode."""
    with ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/device:CPU:0")

    # A device-only inner scope keeps the job of the enclosing scope.
    with ops.device("/job:worker"), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/job:worker/device:CPU:0")

    # An inner device type replaces the outer one.
    with ops.device("/cpu:0"), ops.device("/gpu:0"):
      self.assertEqual(device_util.current(), "/device:GPU:0")

  def testCurrentDeviceWithNonGlobalGraph(self):
    """`current()` also honors device scopes inside a freshly created graph."""
    graph = ops.Graph()
    with graph.as_default(), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/device:CPU:0")

  def testCurrentDeviceWithEager(self):
    """In eager mode `current()` returns a fully qualified local device."""
    # `ops.device` is evaluated only after eager mode has been entered.
    with context.eager_mode(), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(),
                       "/job:localhost/replica:0/task:0/device:CPU:0")

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testCanonicalizeWithoutDefaultDevice(self, mode):
    """Missing replica/task fields are filled in; eager also adds the job."""
    # A job-less spec stays job-less in graph mode but gets the local job
    # in eager mode.
    if mode == "graph":
      local_cpu = "/replica:0/task:0/device:CPU:0"
    else:
      local_cpu = "/job:localhost/replica:0/task:0/device:CPU:0"
    self.assertEqual(device_util.canonicalize("/cpu:0"), local_cpu)

    # Specs that already name a job canonicalize identically in both modes.
    cases = (
        ("/job:worker/cpu:0",
         "/job:worker/replica:0/task:0/device:CPU:0"),
        ("/job:worker/task:1/cpu:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
    )
    for spec, expected in cases:
      self.assertEqual(device_util.canonicalize(spec), expected)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testCanonicalizeWithoutDefaultDeviceCollectiveEnabled(self):
    """With collectives enabled, the server's job/task become the defaults."""
    cluster = multi_worker_test_base.create_cluster_spec(
        has_chief=False, num_workers=1, num_ps=0, has_eval=False)
    cluster_def = server_lib.ClusterSpec(cluster).as_cluster_def()
    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_def,
        job_name="worker",
        task_index=0,
        protocol="grpc",
        port=0)
    context.context().enable_collective_ops(server_def)
    self.assertEqual(
        device_util.canonicalize("/cpu:0"),
        "/job:worker/replica:0/task:0/device:CPU:0")

  def testCanonicalizeWithDefaultDevice(self):
    """Fields absent from the spec are taken from `default`."""
    cases = (
        # (device spec, default, expected canonical name)
        ("/job:worker/task:1/cpu:0", "/gpu:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
        ("/job:worker/task:1", "/gpu:0",
         "/job:worker/replica:0/task:1/device:GPU:0"),
        ("/cpu:0", "/job:worker",
         "/job:worker/replica:0/task:0/device:CPU:0"),
        # A fully specified spec wins over a conflicting default.
        ("/job:worker/replica:0/task:1/device:CPU:0",
         "/job:chief/replica:0/task:1/device:CPU:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
    )
    for spec, default, expected in cases:
      self.assertEqual(device_util.canonicalize(spec, default=default),
                       expected)

  def testResolveWithDeviceScope(self):
    """`resolve()` fills unspecified fields from the active device scope."""
    with ops.device("/gpu:0"):
      gpu_scope_cases = (
          ("/job:worker/task:1/cpu:0",
           "/job:worker/replica:0/task:1/device:CPU:0"),
          # No device type in the spec, so the scope's GPU is used.
          ("/job:worker/task:1",
           "/job:worker/replica:0/task:1/device:GPU:0"),
      )
      for spec, expected in gpu_scope_cases:
        self.assertEqual(device_util.resolve(spec), expected)

    with ops.device("/job:worker"):
      self.assertEqual(
          device_util.resolve("/cpu:0"),
          "/job:worker/replica:0/task:0/device:CPU:0")


if __name__ == "__main__":
  test.main()