# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""

import copy
import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib


# pylint: disable=protected-access


class DistributedVariable(resource_variable_ops.BaseResourceVariable):
  """Represents variables that are replicated.

  It behaves exactly like a normal variable, but uses the corresponding
  variable handle based on the context:
  - In each replica, it uses the handle from that replica.
  - In tpu.replicate(), it uses the replicated handle.
  - Otherwise, it uses the handle from the primary replica.

  Note that, unlike the old DistributedVariable in values.py, it does not
  synchronize automatically.
  """

  def __init__(self, variables, *, enable_packed_handle=False):
    if enable_packed_handle and not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "Argument `enable_packed_handle` is true, but packed handle is only "
          "supported in eager mode. Please make sure eager execution is "
          "enabled.")
    self._variables = variables
    if enable_packed_handle:
      self._packed_handle = ops.pack_eager_tensors(
          [v.handle for v in variables])
    else:
      self._packed_handle = None
    for v in variables:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
    self._device_to_handle = {v.device: v.handle for v in variables}
    self._primary_handle = variables[0].handle
    with ops.init_scope(), \
         ops.name_scope("DistributedVariable", skip_on_eager=False) as name:
      handle_name = ops.name_from_scope_name(name)
      self._unique_id = "%s_%d" % (handle_name, ops.uid())
      if context.executing_eagerly():
        initial_value = None
        initializer = None
      else:
        initial_value = variables[0].initial_value
        initializer = control_flow_ops.group([v.initializer for v in variables])
      super().__init__(
          trainable=variables[0].trainable,
          shape=variables[0].shape,
          dtype=variables[0].dtype,
          handle=None,
          synchronization=variables[0].synchronization,
          constraint=variables[0].constraint,
          aggregation=variables[0].aggregation,
          distribute_strategy=variables[0]._distribute_strategy,
          name=variables[0].name,
          unique_id=self._unique_id,
          handle_name=handle_name,
          graph_element=variables[0]._graph_element,
          initial_value=initial_value,
          initializer_op=initializer,
          is_initialized_op=None,
          cached_value=None,
          caching_device=None,
          is_variables=True)

  @property
  def handle(self):
    """Returns the variable handle to use in the current context.

    Resolution order: the primary handle when saving a non-distributed
    checkpoint, the replicated handle when building a graph inside a TPU
    context, the packed handle (when enabled and not executing eagerly), and
    otherwise the handle of the variable on the current device, falling back
    to the primary handle.
    """
    if values_util.is_saving_non_distributed():
      return self._primary_handle
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context and not context.executing_eagerly():
      is_mirrored = (
          self._variables[0].synchronization !=
          variables_lib.VariableSynchronization.ON_READ)
      if self._packed_handle is None:
        handles = [v.handle for v in self._variables]
        is_packed = False
      else:
        handles = [self._packed_handle]
        is_packed = True
      common_name = self._handle_name
      # BaseResourceVariable appends ":0" to the handle name, which makes it
      # an invalid root scope name.
      if ":" in common_name:
        common_name = common_name.split(":")[0]
      return tpu_context.get_replicated_var_handle(common_name, self._unique_id,
                                                   handles, is_mirrored,
                                                   is_packed)
    if self._packed_handle is not None and not context.executing_eagerly():
      return self._packed_handle
    device = device_util.canonicalize(device_util.current())
    return self._device_to_handle.get(device, self._primary_handle)

  @property
  def name(self):
    if values_util.is_saving_non_distributed():
      return self._variables[0].name
    return super().name

  @property
  def initializer(self):
    if values_util.is_saving_non_distributed():
      return self._variables[0].initializer
    return super().initializer

  def _lazy_read(self, op):
    # Lazy read is not supported.
    with ops.control_dependencies([op]):
      return self.read_value()

  # Begin overrides of read/write methods to satisfy the requirement of using
  # a packed handle, i.e. there must be explicit device annotations.

  def _device_scope(self):
    if (self._packed_handle is None or
        values_util.is_saving_non_distributed() or
        tpu_util.enclosing_tpu_context() is not None):
      return ops.NullContextmanager()
    device = device_util.canonicalize(device_util.current())
    if device in self._device_to_handle:
      return ops.NullContextmanager()
    return ops.device(self._primary_handle.device)
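
  # Illustrative sketch (assumptions for this example: the packed handle is
  # enabled and replicas live on GPU:0 and GPU:1): an update issued from a
  # device that holds no replica, e.g. the host CPU, is explicitly placed on
  # the primary variable's device, so ops on the packed handle always carry a
  # concrete device annotation.
  #
  #   with ops.device("/device:CPU:0"):
  #     dv.assign(2.0)  # Placed under ops.device(primary_handle.device).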

  def value(self):
    # We always force a read_value() instead of using the cached_value, as
    # value() can be called on different devices.
    return self.read_value()

  def read_value(self):
    with self._device_scope():
      return super().read_value()

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_update(sparse_delta, use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().batch_scatter_update(sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_sub(indices, updates, name)

  def scatter_nd_add(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_add(indices, updates, name)

  def scatter_nd_update(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_update(indices, updates, name)

  def sparse_read(self, indices, name=None):
    with self._device_scope():
      return super().sparse_read(indices, name)

  def gather_nd(self, indices, name=None):
    with self._device_scope():
      return super().gather_nd(indices, name)

  def to_proto(self, export_scope=None):
    del self
    raise TypeError("DistributedVariable doesn't support to_proto")

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    raise TypeError("DistributedVariable doesn't support from_proto")

  def _as_graph_element(self):
    if ops.get_default_graph().finalized:
      return self._variables[0]._graph_element
    return self.read_value()

  def _strided_slice_assign(self, *args, **kwargs):
    with self._device_scope():
      return super()._strided_slice_assign(*args, **kwargs)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._variables))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._variables))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)

  def __deepcopy__(self, memo):
    copied_variables = copy.deepcopy(self._variables, memo)
    return DistributedVariable(
        copied_variables, enable_packed_handle=self._packed_handle is not None)


def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  if as_ref:
    raise ValueError(
        "You may be using a variable created under a distribute strategy in "
        "TF 1.x control flows. Try explicitly converting the variable to a "
        "Tensor using variable.read_value(), or switch to TF 2.x.")
  return ops.convert_to_tensor(
      var.read_value(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
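
# With the conversion function registered above, passing a DistributedVariable
# to an op implicitly reads its value on the current device, while a ref
# conversion (TF 1.x control flow) raises a ValueError. A minimal sketch
# (illustrative only; `dv` is the hypothetical variable from the class-level
# example above):
#
#   t = ops.convert_to_tensor(dv)  # Calls dv.read_value() under the hood.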