xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/device_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Device-related support functions."""
16 
17 
18 
19 from tensorflow.python.eager import context
20 from tensorflow.python.framework import config
21 from tensorflow.python.framework import device as tf_device
22 from tensorflow.python.framework import ops
23 
24 
def canonicalize(d, default=None):
  """Return the fully specified device string for `d`.

  Components that `d` leaves unset are filled in from the following sources,
  listed from lowest to highest precedence:
    1. The baseline '/replica:0/task:0/device:CPU:0'.
    2. When executing eagerly outside functions, the first logical CPU device:
       its spec is merged in if it names a job, otherwise the job is set to
       "localhost".
    3. The `default` device string, if one is given.
  Components that `d` does specify always win.

  For example:
    If d = '/cpu:0', default='/job:worker/task:1', it returns
      '/job:worker/replica:0/task:1/device:CPU:0'.
    If d = '/cpu:0', default='/job:worker', it returns
      '/job:worker/replica:0/task:0/device:CPU:0'.
    If d = '/gpu:0', default=None, it returns
      '/replica:0/task:0/device:GPU:0'.

  Args:
    d: a device string or tf.config.LogicalDevice.
    default: optional device string supplying components missing from `d`.

  Returns:
    The canonicalized device string.

  Raises:
    AssertionError: if the device type parsed from `d` is not all upper-case.
  """
  # A LogicalDevice is reduced to its name; anything else is parsed as given.
  device_name = d.name if isinstance(d, context.LogicalDevice) else d
  requested = tf_device.DeviceSpec.from_string(device_name)

  device_type = requested.device_type
  assert device_type is None or device_type == device_type.upper(), (
      "Device type '%s' must be all-caps." % (device_type,))

  # Start from the baseline and layer the more specific sources on top.
  merged = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)

  if ops.executing_eagerly_outside_functions():
    # Try to deduce job, replica and task in case it's in a multi worker setup.
    # TODO(b/151452748): Using list_logical_devices is not always safe since it
    # may return remote devices as well, but we're already doing this elsewhere.
    first_cpu = config.list_logical_devices("CPU")[0]
    host_cpu = tf_device.DeviceSpec.from_string(first_cpu.name)
    # With eager execution and no job on the host CPU, the job is "localhost".
    merged = (merged.make_merged_spec(host_cpu) if host_cpu.job
              else merged.replace(job="localhost"))

  if default:
    # Values from the caller-provided default override the ones deduced above.
    merged = merged.make_merged_spec(tf_device.DeviceSpec.from_string(default))

  # `d` is merged last so that its values take precedence over every default.
  return merged.make_merged_spec(requested).to_string()
75 
76 
def canonicalize_without_job_and_task(d):
  """Canonicalize `d`, then strip the job and task from the result.

  `d` is first run through `canonicalize` (with no default device); the job and
  task of the resulting spec are then cleared and the replica is set to 0.
  This is most useful for parameter server strategy where the device strings are
  generated on the chief, but executed on workers.

  For example:
    If d = '/cpu:0', it returns '/replica:0/device:CPU:0'.
    If d = '/gpu:0', it returns '/replica:0/device:GPU:0'.

  Args:
    d: a device string or tf.config.LogicalDevice.

  Returns:
    A partially canonicalized device string without job and task.
  """
  full_spec = tf_device.DeviceSpec.from_string(canonicalize(d))
  # Drop the parts that differ between the chief and the workers.
  return full_spec.replace(job=None, task=None, replica=0).to_string()
104 
105 
def resolve(d):
  """Canonicalize `d`, filling missing components from the current device.

  Args:
    d: a device string or tf.config.LogicalDevice.

  Returns:
    The result of `canonicalize` with `current()` passed as the default.
  """
  active_device = current()
  return canonicalize(d, default=active_device)
109 
110 
class _FakeNodeDef:
  """Minimal NodeDef stand-in carried by `_FakeOperation`.

  Only the `op` and `name` attributes exist; both start out as empty strings.
  """

  # Restrict instances to exactly these two attributes (no per-instance dict).
  __slots__ = ["op", "name"]

  def __init__(self):
    self.op = self.name = ""
119 
120 
class _FakeOperation:
  """Operation stand-in that can be handed to device functions.

  It starts with empty `device`, `type` and `name` strings and a fresh
  `_FakeNodeDef`, and exposes the two device setters below so that a device
  assignment can be recorded on it.
  """

  def __init__(self):
    self.node_def = _FakeNodeDef()
    self.name = ""
    self.type = ""
    self.device = ""

  def _set_device(self, device):
    """Record `device` after converting it with `ops._device_string`."""
    device_str = ops._device_string(device)  # pylint: disable=protected-access
    self.device = device_str

  def _set_device_from_string(self, device_str):
    """Record `device_str` as the device, unchanged."""
    self.device = device_str
135 
136 
def current():
  """Return a string (not canonicalized) for the current device.

  When executing eagerly outside functions this is the eager context's
  `device_name`. Otherwise the default graph's device functions are applied to
  a throwaway `_FakeOperation` and the device recorded on it is returned.
  """
  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
  if not ops.executing_eagerly_outside_functions():
    # Graph mode: let the graph's device functions place a placeholder op and
    # read back whatever device they assigned to it.
    placeholder_op = _FakeOperation()
    graph = ops.get_default_graph()
    graph._apply_device_functions(placeholder_op)  # pylint: disable=protected-access
    return placeholder_op.device
  return context.context().device_name
147 
148 
def get_host_for_device(device):
  """Returns the corresponding host device for the given device.

  Args:
    device: a device string.

  Returns:
    A device string with the same job, replica and task as `device`, but with
    the device type set to "CPU" and the device index set to 0.
  """
  device_spec = tf_device.DeviceSpec.from_string(device)
  host_spec = tf_device.DeviceSpec(
      job=device_spec.job,
      replica=device_spec.replica,
      task=device_spec.task,
      device_type="CPU",
      device_index=0)
  return host_spec.to_string()
155 
156 
def local_devices_from_num_gpus(num_gpus):
  """Returns device strings for local GPUs or CPU.

  Args:
    num_gpus: number of local GPUs to generate device strings for.

  Returns:
    A tuple with one '/device:GPU:<i>' string for each i in range(num_gpus),
    or the single-element tuple ('/device:CPU:0',) when that range is empty.
  """
  gpu_devices = tuple("/device:GPU:%d" % gpu_id for gpu_id in range(num_gpus))
  if gpu_devices:
    return gpu_devices
  # No GPUs requested: fall back to the CPU.
  return ("/device:CPU:0",)
161