# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values_v2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib


class _VariableInterfaceTestBase(test.TestCase, parameterized.TestCase):
  # This test verifies that DistributedVariable/AutoSyncVariable conform to the
  # Variable and ResourceVariable interface, i.e. the methods and properties are
  # all defined. It verifies methods and properties that have the same code path
  # under different replicas/devices as well. It is not intended to verify
  # methods and properties that behave differently under different
  # replicas/devices; those should be covered by separate tests.
  #
  # Concrete subclasses supply the variable under test by overriding
  # `create_variable`.

  def create_variable(self, initial_value=1., **kwargs):
    """Returns the variable under test; must be overridden by subclasses.

    Args:
      initial_value: initial value for the variable.
      **kwargs: extra keyword arguments for the variable constructor (the tests
        below pass e.g. `trainable`, `synchronization`, `aggregation` and
        `constraint`).

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError

  @property
  def devices(self):
    """Device strings that subclasses may place component variables on."""
    return ["CPU:0", "CPU:1"]

  # ==== Begin Variable interface ===
  # Please follow the same order as methods and properties defined in
  # tf.Variable.

  def testStringify(self):
    v = self.create_variable()
    self.assertIsInstance(v.__str__(), str)
    self.assertIsInstance(v.__repr__(), str)

  def testDenseRead(self):
    v = self.create_variable(1.)
    self.assertEqual(v.value(), 1.)
    self.assertEqual(v.read_value(), 1.)

  def testShape(self):
    v = self.create_variable([1.])
    self.assertEqual(v.shape, (1,))
    self.assertEqual(v.get_shape(), (1,))
    # Setting the same shape is accepted; a different rank must be rejected.
    v.set_shape((1,))
    with self.assertRaisesRegex(ValueError, "not compatible"):
      v.set_shape((1, 1))

  @combinations.generate(combinations.combine(trainable=[True, False]))
  def testTrainable(self, trainable):
    v = self.create_variable(trainable=trainable)
    self.assertEqual(v.trainable, trainable)

  @combinations.generate(
      combinations.combine(synchronization=[
          variables_lib.VariableSynchronization.ON_READ,
          variables_lib.VariableSynchronization.ON_WRITE,
          variables_lib.VariableSynchronization.AUTO,
          variables_lib.VariableSynchronization.NONE,
      ]))
  def testSynchronization(self, synchronization):
    v = self.create_variable(synchronization=synchronization)
    self.assertEqual(v.synchronization, synchronization)

  @combinations.generate(
      combinations.combine(aggregation=[
          variables_lib.VariableAggregation.MEAN,
          variables_lib.VariableAggregation.SUM,
          variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
          variables_lib.VariableAggregation.NONE,
      ]))
  def testAggregation(self, aggregation):
    v = self.create_variable(aggregation=aggregation)
    self.assertEqual(v.aggregation, aggregation)

  @combinations.generate(combinations.combine(mode="graph"))
  def testEval(self):
    v = self.create_variable(1.)
    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())
      self.assertEqual(v.eval(), 1.)

  def testInitialValueEager(self):
    v = self.create_variable(1.)
    # Accessing `initial_value` is expected to raise here (contrast with
    # testInitialValueGraph below, which runs in graph mode).
    with self.assertRaises(RuntimeError):
      v.initial_value  # pylint: disable=pointless-statement

  @combinations.generate(combinations.combine(mode="graph"))
  def testInitialValueGraph(self):
    v = self.create_variable(1.)
    self.assertEqual(self.evaluate(v.initial_value), 1.)

  def testConstraint(self):
    v = self.create_variable(constraint=lambda x: x + 1.)
    self.assertEqual(v.constraint(1.), 2.)

  def testDenseUpdate(self):
    v = self.create_variable(1.)
    # With read_value=True the assign* methods return the updated value; with
    # read_value=False they return None when called outside a tf.function.
    self.assertEqual(
        v.assign(2., use_locking=True, name="assign", read_value=True), 2.)
    self.assertIsNone(v.assign(3., read_value=False))
    self.assertEqual(v, 3.)
    self.assertEqual(
        v.assign_add(1., use_locking=True, name="assign_add", read_value=True),
        4.)
    self.assertIsNone(v.assign_add(1., read_value=False))
    self.assertEqual(v, 5.)
    self.assertEqual(
        v.assign_sub(1., use_locking=True, name="assign_sub", read_value=True),
        4.)
    self.assertIsNone(v.assign_sub(1., read_value=False))
    self.assertEqual(v, 3.)

    @def_function.function
    def f():
      # Inside a tf.function, read_value=False yields an Operation instead.
      self.assertIsInstance(v.assign(1., read_value=False), ops.Operation)
      self.assertIsInstance(v.assign_add(1., read_value=False), ops.Operation)
      self.assertIsInstance(v.assign_sub(1., read_value=False), ops.Operation)

    f()

  def testSparseUpdate(self):
    # Each scatter_* call operates on the state left by the previous one, so
    # the order of the assertions below matters.
    v = self.create_variable([0., 0., 0.])
    self.assertAllEqual(
        v.scatter_add(
            _make_index_slices(values=[1., 2.], indices=[0, 2]),
            use_locking=True,
            name="add"), [1., 0., 2.])
    self.assertAllEqual(
        v.scatter_div(
            _make_index_slices(values=[4., 2.], indices=[0, 2]),
            use_locking=True,
            name="div"), [0.25, 0., 1.])
    self.assertAllEqual(
        v.scatter_max(
            _make_index_slices(values=[1., 0.5], indices=[1, 2]),
            use_locking=True,
            name="max"), [0.25, 1., 1.])
    self.assertAllEqual(
        v.scatter_min(
            _make_index_slices(values=[1., 0.5], indices=[0, 1]),
            use_locking=True,
            name="min"), [0.25, 0.5, 1.])
    self.assertAllEqual(
        v.scatter_mul(
            _make_index_slices(values=[2., 0.5], indices=[0, 1]),
            use_locking=True,
            name="mul"), [0.5, 0.25, 1.])
    self.assertAllEqual(
        v.scatter_sub(
            _make_index_slices(values=[2., 0.5], indices=[0, 1]),
            use_locking=True,
            name="sub"), [-1.5, -0.25, 1.])
    self.assertAllEqual(
        v.scatter_update(
            _make_index_slices(values=[2., 0.5], indices=[0, 1]),
            use_locking=True,
            name="update"), [2., 0.5, 1.])
    self.assertAllEqual(
        v.batch_scatter_update(
            _make_index_slices(values=[1., 1.5], indices=[0, 1]),
            use_locking=True,
            name="update"), [1., 1.5, 1.])

  def testSparseNdUpdate(self):
    # As in testSparseUpdate, the expected values accumulate across calls.
    v = self.create_variable([0., 0., 0., 0.])
    self.assertAllEqual(
        v.scatter_nd_sub([[3], [1]], [1., 2.], name="sub"), [0., -2., 0., -1.])
    self.assertAllEqual(
        v.scatter_nd_add([[2], [0]], [1., 2.], name="add"), [2., -2., 1., -1.])
    self.assertAllEqual(
        v.scatter_nd_update([[1], [3]], [3., 3.], name="update"),
        [2., 3., 1., 3.])

  def testSparseRead(self):
    v = self.create_variable([[1., 2.], [3., 4.]])
    self.assertAllEqual(
        v.sparse_read([1, 0], name="read"), [[3., 4.], [1., 2.]])
    self.assertAllEqual(
        v.gather_nd([[1, 0], [0, 1]], name="gather_nd"), [3., 2.])

  def testTensorConversion(self):
    v = self.create_variable([1.])
    self.assertEqual(ops.convert_to_tensor(v), [1.])

  def testHash(self):
    v = self.create_variable()
    w = self.create_variable()
    d = {}
    # The variable itself must not be usable as a dict key here; `ref()` is the
    # hashable handle to use instead.
    with self.assertRaises(TypeError):
      d[v] = 1
    d[v.ref()] = 1
    self.assertEqual(d[v.ref()], 1)
    self.assertNotIn(w.ref(), d)

  @combinations.generate(combinations.combine(mode="graph"))
  def testHashGraph(self):
    v = self.create_variable()
    w = self.create_variable()
    d = {v: 1}
    self.assertEqual(d[v], 1)
    self.assertNotIn(w, d)

  def testEquality(self):
    v = self.create_variable(1.)
    w = self.create_variable(2.)
    x = self.create_variable(1.)
    self.assertEqual(v, x)
    self.assertNotEqual(v, w)

  @combinations.generate(combinations.combine(mode="graph"))
  def testEqualityGraph(self):
    # In legacy graph mode, tensor equality is object equality
    v = self.create_variable(1.)
    w = self.create_variable(1.)
    self.assertNotEqual(v, w)
    self.assertEqual(v, v)

  def testIteration(self):
    v = self.create_variable([1.])
    self.assertEqual([1.], list(iter(v)))

  def testProperties(self):
    v = self.create_variable()
    self.assertIsInstance(v.name, str)
    # _shared_name is also part of the interface. E.g. it's used in optimizer to
    # determine slot variable key.
    self.assertIsInstance(v._shared_name, str)
    self.assertIsNone(v.initializer)
    self.assertIsInstance(v.device, str)
    self.assertEqual(v.dtype, dtypes.float32)
    with self.assertRaises(AttributeError):
      v.op  # pylint: disable=pointless-statement
    with self.assertRaises(AttributeError):
      v.graph  # pylint: disable=pointless-statement

  @combinations.generate(combinations.combine(mode="graph"))
  def testPropertiesGraph(self):
    v = self.create_variable()
    self.assertIsInstance(v.initializer, ops.Operation)
    self.assertIsInstance(v.op, ops.Operation)
    self.assertIsInstance(v.graph, ops.Graph)

  def testProtoConversion(self):
    # to_proto and from_proto are not supported.
    v = self.create_variable([1, 2])
    with self.assertRaises(TypeError):
      v.to_proto()
    with self.assertRaises(TypeError):
      v.from_proto(variable_def=None)

  def testSaveSliceInfo(self):
    v = self.create_variable()
    slice_info = variables_lib.Variable.SaveSliceInfo()
    v._set_save_slice_info(slice_info)
    self.assertIs(v._get_save_slice_info(), slice_info)
    # Some code accesses _save_slice_info directly without using the getter.
    self.assertIs(v._save_slice_info, slice_info)

  def testOperatorOverride(self):
    v = self.create_variable(7)
    self.assertEqual(v + 1, 8)
    self.assertEqual(3 + v, 10)
    self.assertEqual(v + v, 14)
    self.assertEqual(v - 2, 5)
    self.assertEqual(13 - v, 6)
    self.assertEqual(v - v, 0)
    self.assertEqual(v * 2, 14)
    self.assertEqual(3 * v, 21)
    self.assertEqual(v * v, 49)
    self.assertEqual(v / 2, 3.5)
    self.assertEqual(14 / v, 2.)
    self.assertEqual(v // 2, 3)
    self.assertEqual(15 // v, 2)
    self.assertEqual(v % 2, 1)
    self.assertEqual(16 % v, 2)
    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert
    self.assertEqual(v & 3, 3)
    self.assertEqual(11 & v, 3)
    self.assertEqual(v | 8, 15)
    self.assertEqual(16 | v, 23)
    self.assertEqual(v ^ 3, 4)
    self.assertEqual(11 ^ v, 12)
    self.assertEqual(pow(v, 3), 343)
    # TODO(b/178748613): pow(v, 3, 10) fails.
    self.assertEqual(pow(2, v), 128)
    self.assertEqual(-v, -7)
    self.assertEqual(~v, ~7)
    self.assertEqual(abs(v), 7)

  def testSlice(self):
    v = self.create_variable([1., 2., 3.])
    self.assertEqual(v[1], 2.)
    # Assigning through a slice must write back into the variable.
    v[2].assign(4.)
    self.assertAllEqual(v, [1., 2., 4.])

  # ==== End Variable interface ===

  # ==== Begin ResourceVariable interface ===
  def testHandle(self):
    v = self.create_variable()
    self.assertIsInstance(v.handle, ops.Tensor)
    self.assertEqual(v.handle.dtype, dtypes.resource)

  def testInGraphMode(self):
    # This is protected but used in a lot of places internally.
    v = self.create_variable()
    self.assertFalse(v._in_graph_mode)

  def testUniqueId(self):
    # This is used in optimizer as part of slot variable key.
    v = self.create_variable()
    w = self.create_variable()
    self.assertNotEqual(v._unique_id, w._unique_id)

  def testIsResourceVariable(self):
    v = self.create_variable()
    self.assertTrue(resource_variable_ops.is_resource_variable(v))
  # ==== End ResourceVariable interface ===

  @combinations.generate(combinations.combine(mode="graph"))
  def testAsGraphElement(self):
    g = ops.Graph()
    with g.as_default():
      v = self.create_variable(1.)
      # Finalize first so that any attempt to add ops afterwards would fail.
      g.finalize()
      self.evaluate(v.initializer)
      # _as_graph_element shouldn't create new operations.
      self.assertEqual(self.evaluate(v._as_graph_element()), 1.)


class DistributedVariableInterfaceTest(_VariableInterfaceTestBase):
  """Runs the variable interface tests against `values_v2.DistributedVariable`."""

  def create_variable(self, initial_value=1., **kwargs):
    """Builds a DistributedVariable with one component per entry in `devices`.

    Args:
      initial_value: initial value given to every component variable.
      **kwargs: forwarded unchanged to each `variables_lib.Variable`.

    Returns:
      A `values_v2.DistributedVariable` wrapping the component variables.
    """
    variables = []
    for device in self.devices:
      with ops.device(device):
        variables.append(
            variables_lib.Variable(initial_value, **kwargs))
    return values_v2.DistributedVariable(variables)


# Prevent the base class from running.
del _VariableInterfaceTestBase


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.tpu_strategy,
            strategy_combinations.mirrored_strategy_with_two_cpus,
            strategy_combinations.mirrored_strategy_with_two_gpus,
        ],
        enable_packed_handle=[True, False],
        tf_function=[combinations.tf_function, combinations.no_tf_function]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
  """Tests device-dependent reads and writes of `DistributedVariable`."""

  def create_variable(self, strategy, initial_value, enable_packed_handle,
                      **kwargs):
    """Builds a DistributedVariable over the strategy's parameter devices.

    Args:
      strategy: strategy whose `extended.parameter_devices` determine where the
        component variables are placed.
      initial_value: initial value given to every component variable.
      enable_packed_handle: forwarded to `values_v2.DistributedVariable`.
      **kwargs: forwarded unchanged to each `variables_lib.Variable`.

    Returns:
      A `values_v2.DistributedVariable` wrapping the component variables.
    """
    variables = []
    for device in strategy.extended.parameter_devices:
      with ops.device(device):
        variables.append(variables_lib.Variable(initial_value, **kwargs))
    return values_v2.DistributedVariable(
        variables, enable_packed_handle=enable_packed_handle)

  def assertReplica(self, distributed_var, values):
    """Asserts component variables equal `values`, pairwise and in order.

    Note: the pairing uses `zip`, so only the first
    `min(len(distributed_var._variables), len(values))` components are checked.

    Args:
      distributed_var: object exposing its components as `_variables`.
      values: expected value for each component, in the same order.
    """
    for var, value in zip(distributed_var._variables, values):
      self.assertAllEqual(var, value)

  def testRead(self, strategy, enable_packed_handle, tf_function):
    v = self.create_variable(strategy, 0., enable_packed_handle)

    # Give the first two components distinct values so that reads can be told
    # apart by device.
    with ops.device(strategy.extended.parameter_devices[0]):
      v.assign(1.)
    with ops.device(strategy.extended.parameter_devices[1]):
      v.assign(2.)

    @tf_function
    def read_device0():
      with ops.device(strategy.extended.parameter_devices[0]):
        return v.read_value(), v.value()

    @tf_function
    def read_device1():
      with ops.device(strategy.extended.parameter_devices[1]):
        return v.read_value(), v.value()

    @tf_function
    def read_other_device():
      with ops.device("CPU:0"):
        return v.read_value(), v.value()

    self.assertAllEqual(read_device0(), [1., 1.])
    self.assertAllEqual(read_device1(), [2., 2.])
    # Under "CPU:0" the read is expected to yield the value that was assigned
    # under parameter_devices[0].
    self.assertAllEqual(read_other_device(), [1., 1.])

  def testAssign(self, strategy, enable_packed_handle, tf_function):
    v = self.create_variable(strategy, 0., enable_packed_handle)

    @tf_function
    def update_device0():
      with ops.device(strategy.extended.parameter_devices[0]):
        v.assign(1.)

    @tf_function
    def update_device1():
      with ops.device(strategy.extended.parameter_devices[1]):
        v.assign(2.)

    update_device0()
    update_device1()
    self.assertReplica(v, [1., 2.])

    with ops.device("CPU:0"):
      # Update the primary replica.
      v.assign(3.)
    self.assertReplica(v, [3., 2.])

  def testStrategyRun(self, strategy, enable_packed_handle, tf_function):
    if (test_util.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest("tpu doesn't support eager")
    v = self.create_variable(strategy, 0., enable_packed_handle)

    @tf_function
    def update(per_replica):
      v.assign(per_replica)

    @tf_function
    def read():
      return v.read_value()

    # Each replica assigns its own per-replica value, then reads it back.
    strategy.run(
        update, args=(test_util.create_per_replica(strategy, [1., 2.]),))
    self.assertReplica(v, [1., 2.])
    self.assertAllEqual(
        test_util.gather(strategy, strategy.run(read)), [1., 2.])


def _make_index_slices(values, indices, dense_shape=None):
  """Builds an `IndexedSlices` from plain Python values.

  `values` and `indices` (and `dense_shape`, when given) are each passed
  through `array_ops.identity` before being handed to `IndexedSlices`.

  Args:
    values: the slice values.
    indices: the indices the slices correspond to.
    dense_shape: optional dense shape; `None` means "not specified".

  Returns:
    An `indexed_slices.IndexedSlices`.
  """
  # Compare against None explicitly: truthiness is not a reliable presence
  # check for shapes (it is ambiguous or raises for tensor/array inputs, and an
  # explicitly supplied empty shape would be treated as absent).
  if dense_shape is not None:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":
  test_util.main()