# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""

import copy
import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib


# pylint: disable=protected-access


class DistributedVariable(resource_variable_ops.BaseResourceVariable):
  """Represents variables that are replicated.

  It behaves exactly like a normal variable, but uses the corresponding
  variable handle based on the context:
  - In each replica, it uses the handle from that replica.
  - In tpu.replicate(), it uses the replicated handle.
  - Otherwise, it uses the handle from the primary replica.

  Note that it doesn't synchronize automatically, unlike the old
  DistributedVariable in values.py. A commented construction sketch is
  included at the end of this module.
  """

  def __init__(self, variables, *, enable_packed_handle=False):
    if enable_packed_handle and not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "Argument `enable_packed_handle` is true, but packed handle is only "
          "supported in eager mode. Please make sure eager execution is "
          "enabled.")
    self._variables = variables
    if enable_packed_handle:
      self._packed_handle = ops.pack_eager_tensors(
          [v.handle for v in variables])
    else:
      self._packed_handle = None
    for v in variables:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
    self._device_to_handle = {v.device: v.handle for v in variables}
    self._primary_handle = variables[0].handle
    with ops.init_scope(), \
        ops.name_scope("DistributedVariable", skip_on_eager=False) as name:
      handle_name = ops.name_from_scope_name(name)
      self._unique_id = "%s_%d" % (handle_name, ops.uid())
      if context.executing_eagerly():
        initial_value = None
        initializer = None
      else:
        initial_value = variables[0].initial_value
        initializer = control_flow_ops.group(
            [v.initializer for v in variables])
      super().__init__(
          trainable=variables[0].trainable,
          shape=variables[0].shape,
          dtype=variables[0].dtype,
          handle=None,
          synchronization=variables[0].synchronization,
          constraint=variables[0].constraint,
          aggregation=variables[0].aggregation,
          distribute_strategy=variables[0]._distribute_strategy,
          name=variables[0].name,
          unique_id=self._unique_id,
          handle_name=handle_name,
          graph_element=variables[0]._graph_element,
          initial_value=initial_value,
          initializer_op=initializer,
          is_initialized_op=None,
          cached_value=None,
          caching_device=None,
          is_variables=True)

  @property
  def handle(self):
    if values_util.is_saving_non_distributed():
      return self._primary_handle
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context and not context.executing_eagerly():
      # Inside tpu.replicate(), use the replicated handle.
      is_mirrored = (
          self._variables[0].synchronization !=
          variables_lib.VariableSynchronization.ON_READ)
      if self._packed_handle is None:
        handles = [v.handle for v in self._variables]
        is_packed = False
      else:
        handles = [self._packed_handle]
        is_packed = True
      common_name = self._handle_name
      # BaseResourceVariable appends ":0" to the handle name, which makes it
      # not a valid root scope name.
      if ":" in common_name:
        common_name = common_name.split(":")[0]
      return tpu_context.get_replicated_var_handle(common_name, self._unique_id,
                                                   handles, is_mirrored,
                                                   is_packed)
    if self._packed_handle is not None and not context.executing_eagerly():
      return self._packed_handle
    device = device_util.canonicalize(device_util.current())
    return self._device_to_handle.get(device, self._primary_handle)

  @property
  def name(self):
    if values_util.is_saving_non_distributed():
      return self._variables[0].name
    return super().name

  @property
  def initializer(self):
    if values_util.is_saving_non_distributed():
      return self._variables[0].initializer
    return super().initializer

  def _lazy_read(self, op):
    # Lazy read is not supported.
    with ops.control_dependencies([op]):
      return self.read_value()

  # Begin overrides of read/write methods to satisfy the requirement of using
  # a packed handle, i.e. there must be explicit device annotations.

  def _device_scope(self):
    if (self._packed_handle is None or
        values_util.is_saving_non_distributed() or
        tpu_util.enclosing_tpu_context() is not None):
      return ops.NullContextmanager()
    device = device_util.canonicalize(device_util.current())
    if device in self._device_to_handle:
      return ops.NullContextmanager()
    return ops.device(self._primary_handle.device)

  def value(self):
    # We always force a read_value() instead of using the cached_value, as
    # value() can be called on different devices.
    return self.read_value()

  def read_value(self):
    with self._device_scope():
      return super().read_value()

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_update(sparse_delta, use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().batch_scatter_update(sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_sub(indices, updates, name)

  def scatter_nd_add(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_add(indices, updates, name)

  def scatter_nd_update(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_update(indices, updates, name)

  def sparse_read(self, indices, name=None):
    with self._device_scope():
      return super().sparse_read(indices, name)

  def gather_nd(self, indices, name=None):
    with self._device_scope():
      return super().gather_nd(indices, name)

  def to_proto(self, export_scope=None):
    del self
    raise TypeError("DistributedVariable doesn't support to_proto")

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    raise TypeError("DistributedVariable doesn't support from_proto")

  def _as_graph_element(self):
    if ops.get_default_graph().finalized:
      return self._variables[0]._graph_element
    return self.read_value()

  def _strided_slice_assign(self, *args, **kwargs):
    with self._device_scope():
      return super()._strided_slice_assign(*args, **kwargs)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._variables))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._variables))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)

  def __deepcopy__(self, memo):
    copied_variables = copy.deepcopy(self._variables, memo)
    return DistributedVariable(
        copied_variables, enable_packed_handle=self._packed_handle is not None)


def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  if as_ref:
    raise ValueError(
        "You may be using a variable created under a distribute strategy in "
        "TF 1.x control flows. Try explicitly converting the variable to a "
        "Tensor using variable.read_value(), or switch to TF 2.x.")
  return ops.convert_to_tensor(
      var.read_value(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
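

# The sketch below (an editor's illustration, not part of the original module)
# shows one way a DistributedVariable could be assembled from per-device
# component variables and then used like a regular variable. The device
# strings are placeholders and the snippet assumes eager execution with the
# named devices available; it is kept commented out so importing this module
# has no side effects.
#
#   from tensorflow.python.framework import ops
#   from tensorflow.python.ops import variables as variables_lib
#
#   with ops.device("/device:CPU:0"):
#     v0 = variables_lib.Variable(1.0)
#   with ops.device("/device:GPU:0"):
#     v1 = variables_lib.Variable(1.0)
#
#   # Wrap the per-device components. With the default
#   # enable_packed_handle=False, reads and writes are dispatched using the
#   # handle that matches the current device, falling back to the primary.
#   dv = DistributedVariable([v0, v1])
#   dv.assign_add(1.0)
#   print(dv.read_value())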