xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/device_util_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for device utilities."""
16
17from absl.testing import parameterized
18
19from tensorflow.core.protobuf import tensorflow_server_pb2
20from tensorflow.python.distribute import combinations
21from tensorflow.python.distribute import device_util
22from tensorflow.python.distribute import multi_worker_test_base
23from tensorflow.python.eager import context
24from tensorflow.python.framework import ops
25from tensorflow.python.platform import test
26from tensorflow.python.training import server_lib
27
28
class DeviceUtilTest(test.TestCase, parameterized.TestCase):
  """Exercises the device-string helpers in `device_util`.

  Covers `current()`, `canonicalize()` and `resolve()` under graph and
  eager execution.
  """

  def setUp(self):
    """Resets the global eager context before every test."""
    super().setUp()
    context._reset_context()  # pylint: disable=protected-access

  @combinations.generate(combinations.combine(mode="graph"))
  def testCurrentDeviceWithGlobalGraph(self):
    """`current()` reflects the active `ops.device` scopes in graph mode."""
    with ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/device:CPU:0")

    # A job-only outer scope is combined with the inner device scope.
    with ops.device("/job:worker"), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/job:worker/device:CPU:0")

    # The innermost device type replaces the outer one.
    with ops.device("/cpu:0"), ops.device("/gpu:0"):
      self.assertEqual(device_util.current(), "/device:GPU:0")

  def testCurrentDeviceWithNonGlobalGraph(self):
    """`current()` honours device scopes inside a newly created graph."""
    with ops.Graph().as_default(), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/device:CPU:0")

  def testCurrentDeviceWithEager(self):
    """In eager mode `current()` yields a fully qualified localhost device."""
    want = "/job:localhost/replica:0/task:0/device:CPU:0"
    with context.eager_mode(), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), want)

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testCanonicalizeWithoutDefaultDevice(self, mode):
    """`canonicalize()` fills in replica/task, and the job in eager mode."""
    # Graph mode leaves the job unset; eager mode pins it to localhost.
    local_cpu = ("/replica:0/task:0/device:CPU:0" if mode == "graph" else
                 "/job:localhost/replica:0/task:0/device:CPU:0")
    cases = (
        ("/cpu:0", local_cpu),
        ("/job:worker/cpu:0", "/job:worker/replica:0/task:0/device:CPU:0"),
        ("/job:worker/task:1/cpu:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
    )
    for spec, expected in cases:
      self.assertEqual(device_util.canonicalize(spec), expected)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testCanonicalizeWithoutDefaultDeviceCollectiveEnabled(self):
    """With collective ops enabled, job/task come from the server def."""
    raw_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=False, num_workers=1, num_ps=0, has_eval=False)
    cluster_def = server_lib.ClusterSpec(raw_spec).as_cluster_def()
    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_def,
        job_name="worker",
        task_index=0,
        protocol="grpc",
        port=0)
    context.context().enable_collective_ops(server_def)

    # The bare "/cpu:0" now resolves to this server's worker task 0.
    self.assertEqual(
        device_util.canonicalize("/cpu:0"),
        "/job:worker/replica:0/task:0/device:CPU:0")

  def testCanonicalizeWithDefaultDevice(self):
    """Fields missing from the spec are taken from `default`."""
    cases = (
        # An explicit device type in the spec beats the default's.
        ("/job:worker/task:1/cpu:0", "/gpu:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
        # No device type in the spec: it is taken from the default.
        ("/job:worker/task:1", "/gpu:0",
         "/job:worker/replica:0/task:1/device:GPU:0"),
        # No job in the spec: it is taken from the default.
        ("/cpu:0", "/job:worker",
         "/job:worker/replica:0/task:0/device:CPU:0"),
        # A fully specified spec comes back unchanged.
        ("/job:worker/replica:0/task:1/device:CPU:0",
         "/job:chief/replica:0/task:1/device:CPU:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
    )
    for spec, fallback, expected in cases:
      self.assertEqual(
          device_util.canonicalize(spec, default=fallback), expected)

  def testResolveWithDeviceScope(self):
    """`resolve()` merges the spec with the enclosing device scope."""
    under_gpu_scope = (
        # The spec's explicit CPU wins over the enclosing GPU scope.
        ("/job:worker/task:1/cpu:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
        # No device type in the spec: the GPU scope supplies it.
        ("/job:worker/task:1",
         "/job:worker/replica:0/task:1/device:GPU:0"),
    )
    with ops.device("/gpu:0"):
      for spec, expected in under_gpu_scope:
        self.assertEqual(device_util.resolve(spec), expected)

    # A job-only scope supplies the job for a bare device spec.
    with ops.device("/job:worker"):
      self.assertEqual(
          device_util.resolve("/cpu:0"),
          "/job:worker/replica:0/task:0/device:CPU:0")
122
123
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner for every test in this module.
  test.main()
126