# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test MirroredVariable in MirroredStrategy and MultiWorkerMirroredStrategy."""

from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save


def _replica_id():
  """Returns the current replica's id in the sync group as a tensor.

  The id read from the replica context is wrapped with `constant_op.constant`
  when it is not already an `ops.Tensor`.
  """
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if not isinstance(replica_id, ops.Tensor):
    replica_id = constant_op.constant(replica_id)
  return replica_id


def _mimic_two_cpus():
  """Configures the first physical CPU with two logical devices."""
  cpus = config.list_physical_devices("CPU")

  config.set_logical_device_configuration(cpus[0], [
      context.LogicalDeviceConfiguration(),
      context.LogicalDeviceConfiguration(),
  ])


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            combinations.NamedDistribution(
                "Collective2CPUs",
                # pylint: disable=g-long-lambda
                lambda: collective_all_reduce_strategy.
                CollectiveAllReduceStrategy._from_local_devices((
                    "/device:CPU:0", "/device:CPU:1")),
                required_gpus=0)
        ],
        mode=["graph", "eager"]))
class MirroredVariableCreationTest(test.TestCase):
  """Base class that tests mirrored variable creator.

  Currently it assumes all strategy objects have two replicas.
  """

  @classmethod
  def setUpClass(cls):
    """Splits the CPU into two logical devices for the tests in this class."""
    _mimic_two_cpus()

  def assertAllDifferent(self, objs):
    """Asserts that no two elements of `objs` are the same object.

    Args:
      objs: A sequence supporting iteration and slicing.
    """
    # Identity comparison is symmetric, so each unordered pair is checked
    # exactly once instead of once per ordering.
    for i, first in enumerate(objs):
      for second in objs[i + 1:]:
        self.assertIsNot(first, second)

  # TODO(priyag): Modify more tests to use this helper and check more
  # properties.
  def _test_mv_properties(self, var, name, strategy):
    """Asserts the basic properties expected of a mirrored variable.

    Args:
      var: The variable to check.
      name: The expected value of `var.name`.
      strategy: The strategy expected on `var` and on each local component.
    """
    self.assertTrue(distribute_utils.is_mirrored(var))
    self.assertEqual(name, var.name)
    self.assertIs(strategy, var.distribute_strategy)
    # The local components do not change inside the loop; fetch them once.
    local_results = strategy.experimental_local_results(var)
    for i, device in enumerate(var._devices):  # pylint: disable=protected-access
      component = local_results[i]
      self.assertEqual(device, component.device)
      self.assertIs(strategy, component._distribute_strategy)  # pylint: disable=protected-access

  def testVariableInFuncGraph(self, distribution):
    """Checks variables created while a `FuncGraph` is the default graph."""

    def model_fn():
      v = variable_scope.variable(2.0, name="bar")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
      v1 = variable_scope.variable(1.0, name="foo")
      v2 = distribution.extended.call_for_each_replica(model_fn)

    self._test_mv_properties(v1, "foo:0", distribution)
    self._test_mv_properties(v2, "bar:0", distribution)

  def testVariableWithTensorInitialValueInFunction(self, distribution):
    """Checks a variable with a tensor initial value made in a function."""
    if not context.executing_eagerly():
      self.skipTest("`tf.function` is an eager-only feature")

    # Single-element list so that `model_fn` can create the variable once and
    # reuse it on later calls.
    v = [None]

    def model_fn():
      if v[0] is None:
        init_val = array_ops.zeros([])
        v[0] = variables.Variable(init_val)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v[0]

    @def_function.function(autograph=False)
    def make_v1():
      return distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn))

    self.assertAllEqual([0, 0], make_v1())

  def testSingleVariable(self, distribution):
    """Checks a single named variable created per replica."""

    def model_fn():
      # This variable should be created only once across the threads because of
      # special variable_creator functions used by
      # `distribution.extended.call_for_each_replica`.
      v = variable_scope.variable(1.0, name="foo")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self._test_mv_properties(result, "foo:0", distribution)

  def testUnnamedVariable(self, distribution):
    """Checks that an unnamed variable is named "Variable:0"."""

    def model_fn():
      v = variable_scope.variable(1.0)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self._test_mv_properties(result, "Variable:0", distribution)

  def testMultipleVariables(self, distribution):
    """Checks five named variables created in one replica function."""

    def model_fn():
      vs = []
      for i in range(5):
        vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return vs

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      for i, v in enumerate(result):
        self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)

  def testMultipleVariablesWithSameCanonicalName(self, distribution):
    """Checks that similarly named variables keep their requested names."""

    def model_fn():
      vs = []
      vs.append(variable_scope.variable(1.0, name="foo/bar"))
      vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
      vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
      vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return vs

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      for v in result:
        self.assertTrue(distribute_utils.is_mirrored(v))
      self.assertEqual(4, len(result))
      self.assertEqual("foo/bar:0", result[0].name)
      self.assertEqual("foo_1/bar:0", result[1].name)
      self.assertEqual("foo_1/bar_1:0", result[2].name)
      self.assertEqual("foo/bar_1:0", result[3].name)

  def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
    """Checks the name used when replicas request different names."""

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_mirrored(result))
      # The resulting mirrored variable will use the name from the first device.
      self.assertEqual("foo_0:0", result.name)

  def testWithVariableAndVariableScope(self, distribution):
    """Checks `variable()` names and kinds inside a variable scope."""

    def model_fn():
      v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.variable(1.0, name="var1")
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        v2 = variable_scope.variable(
            1.0,
            name="var2",
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.variable(
            1.0,
            name="var3",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3

    with distribution.scope():
      v = variable_scope.variable(1.0, name="var-main0")
      self.assertEqual("var-main0:0", v.name)

      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(4, len(result))
      v0, v1, v2, v3 = result
      self.assertTrue(distribute_utils.is_mirrored(v0))
      self.assertEqual("var0:0", v0.name)
      self.assertTrue(distribute_utils.is_mirrored(v1))
      self.assertEqual("common/var1:0", v1.name)
      self.assertTrue(distribute_utils.is_sync_on_read(v2))
      self.assertEqual("common/var2:0", v2.name)
      self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
      self.assertTrue(distribute_utils.is_mirrored(v3))
      self.assertEqual("common/var3:0", v3.name)
      self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)

  def testWithGetVariableAndVariableScope(self, distribution):
    """Checks `get_variable()` names and kinds inside nested scopes."""

    def model_fn():
      v0 = variable_scope.get_variable("var0", [1])
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.get_variable("var1", [1])
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        v2 = variable_scope.get_variable(
            "var2", [1],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.get_variable(
            "var3", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3

    with distribution.scope():
      with variable_scope.variable_scope("main"):
        v = variable_scope.get_variable("var-main0", [1])
        self.assertEqual("main/var-main0:0", v.name)

        result = distribution.extended.call_for_each_replica(model_fn)
        self.assertEqual(4, len(result))
        v0, v1, v2, v3 = result
        self.assertTrue(distribute_utils.is_mirrored(v0))
        self.assertEqual("main/var0:0", v0.name)
        self.assertTrue(distribute_utils.is_mirrored(v1))
        self.assertEqual("main/common/var1:0", v1.name)
        self.assertTrue(distribute_utils.is_sync_on_read(v2))
        self.assertEqual("main/common/var2:0", v2.name)
        self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
        self.assertTrue(distribute_utils.is_mirrored(v3))
        self.assertEqual("main/common/var3:0", v3.name)
        self.assertEqual(variable_scope.VariableAggregation.MEAN,
                         v3.aggregation)

  def testOnlyFirstReplicaUpdatesVariables(self, distribution):
    """Checks `ONLY_FIRST_REPLICA` aggregation with ON_READ and ON_WRITE."""

    def create_fn():
      aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
      v0 = variable_scope.variable(
          2.0,
          name="on_read",
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=aggregation)
      v1 = variable_scope.variable(
          3.0,
          name="on_write",
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=aggregation)
      return v0, v1

    with distribution.scope():
      v0, v1 = distribution.extended.call_for_each_replica(create_fn)
      self.evaluate(v0.initializer)
      self.assertEqual(
          2.0, self.evaluate(distribution.experimental_local_results(v0)[0]))
      self.assertEqual(
          2.0, self.evaluate(distribution.experimental_local_results(v0)[1]))
      self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
      self.evaluate(v1.initializer)
      self.assertEqual(
          3.0, self.evaluate(distribution.experimental_local_results(v1)[0]))
      self.assertEqual(
          3.0, self.evaluate(distribution.experimental_local_results(v1)[1]))
      self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))

      def replica_id_plus_one():
        return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)

      # Update using the assign_add member function.
      def update_member_fn():
        update0 = v0.assign_add(5.0 * replica_id_plus_one())
        update1 = v1.assign_add(7.0 * replica_id_plus_one())
        return update0, update1

      update0a, update1a = distribution.extended.call_for_each_replica(
          update_member_fn)

      # Update "sync on read" variable.
      self.evaluate(distribution.group(update0a))
      local_results = self.evaluate(distribution.experimental_local_results(v0))
      self.assertEqual(2.0 + 5.0, local_results[0])
      # Writes are not synchronized for "sync on read" variables,
      # so device[1] can end up with a different value.
      self.assertEqual(2.0 + 2 * 5.0, local_results[1])
      # Always reads from device 0.
      self.assertEqual(2.0 + 5.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1a))
      local_results1 = self.evaluate(
          distribution.experimental_local_results(v1))
      self.assertEqual(3.0 + 7.0, local_results1[0])
      # Writes are synchronized for v1, only the argument to assign_add on
      # device[0] is used.
      self.assertEqual(3.0 + 7.0, local_results1[1])
      self.assertEqual(3.0 + 7.0,
                       self.evaluate(distribution.extended.read_var(v1)))

      # Update using state_ops.assign_add global function.
      def update_state_ops_fn():
        update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
        update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
        return update0, update1

      update0b, update1b = distribution.extended.call_for_each_replica(
          update_state_ops_fn)
      self.evaluate(distribution.group(update0b))

      # Update "sync on read" variable.
      local_results = self.evaluate(distribution.experimental_local_results(v0))
      self.assertEqual(2.0 + 5.0 + 11.0, local_results[0])
      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0, local_results[1])
      self.assertEqual(2.0 + 5.0 + 11.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1b))
      local_results1 = self.evaluate(
          distribution.experimental_local_results(v1))
      self.assertEqual(3.0 + 7.0 + 13.0, local_results1[0])
      self.assertEqual(3.0 + 7.0 + 13.0, local_results1[1])
      self.assertEqual(3.0 + 7.0 + 13.0,
                       self.evaluate(distribution.extended.read_var(v1)))

  def testNoneSynchronizationWithGetVariable(self, distribution):
    """Checks that `get_variable` rejects `NONE` synchronization."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with "):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testNoneSynchronizationWithVariable(self, distribution):
    """Checks that `variable` rejects `NONE` synchronization."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with "):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testInvalidSynchronizationWithVariable(self, distribution):
    """Checks that `variable` rejects an unknown synchronization value."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable synchronization mode: Invalid for "
          "variable: v"):
        variable_scope.variable(1.0, name="v", synchronization="Invalid")

  def testInvalidAggregationWithGetVariable(self, distribution):
    """Checks that `get_variable` rejects an unknown aggregation value."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testInvalidAggregationWithVariable(self, distribution):
    """Checks that `variable` rejects an unknown aggregation value."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testNonMatchingVariableCreation(self, distribution):
    """Checks that per-replica variable names that differ raise an error."""

    def model_fn(name):
      v = variable_scope.variable(1.0, name=name)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      names = values.PerReplica(("foo", "bar"))
      with self.assertRaises(RuntimeError):
        _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))

  def testSyncOnReadVariable(self, distribution):
    """Checks wrappers, components and reads of ON_READ SUM/MEAN variables."""

    # Objects observed inside each replica, keyed by the evaluated replica id.
    all_v_sum = {}
    all_v_mean = {}
    components_sum = {}
    components_mean = {}

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      v_mean = variable_scope.variable(
          4.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.MEAN)
      self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
      self.assertTrue(distribute_utils.is_sync_on_read(v_mean))
      updates = [
          v_sum.assign_add(2.0 + replica_id),
          v_mean.assign(6.0 * replica_id)
      ]
      all_v_sum[replica_id] = v_sum
      all_v_mean[replica_id] = v_mean
      c_sum = v_sum._get()
      c_mean = v_mean._get()
      components_sum[replica_id] = c_sum
      components_mean[replica_id] = c_mean
      self.assertIsNot(v_sum, c_sum)
      self.assertIsNot(v_mean, c_mean)
      return updates, v_sum, v_mean, c_sum, c_mean

    with distribution.scope():
      # Create "sum" and "mean" versions of SyncOnReadVariables.
      ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
          distribution.extended.call_for_each_replica(model_fn))
      # Should see the same wrapping instance in all replicas.
      self.assertIs(all_v_sum[0], ret_v_sum)
      self.assertIs(all_v_mean[0], ret_v_mean)
      self.assertIs(all_v_sum[0], all_v_sum[1])
      self.assertIs(all_v_mean[0], all_v_mean[1])

      # Regroup should recover the same wrapper.
      self.assertIs(ret_v_sum, regrouped_sum)
      self.assertIs(ret_v_mean, regrouped_mean)
      self.assertIsNot(components_sum[0], components_sum[1])
      self.assertIsNot(components_mean[0], components_mean[1])

      # Apply updates
      self.evaluate(variables.global_variables_initializer())
      self.evaluate([
          y for x in ret_ops  # pylint: disable=g-complex-comprehension
          for y in distribution.experimental_local_results(x)
      ])
      expected_sum = 0.0
      expected_mean = 0.0
      for i, _ in enumerate(distribution.extended.worker_devices):
        # Should see different values on different devices.
        v_sum_value = self.evaluate(
            distribution.experimental_local_results(ret_v_sum)[i].read_value())
        v_mean_value = self.evaluate(
            distribution.experimental_local_results(ret_v_mean)[i].read_value())
        expected = i + 3.0
        self.assertEqual(expected, v_sum_value)
        expected_sum += expected
        expected = i * 6.0
        self.assertEqual(expected, v_mean_value)
        expected_mean += expected
      expected_mean /= len(distribution.extended.worker_devices)

      # Without get(device), should return the value you get by
      # applying the reduction across all replicas (whether you use
      # read_var(), get(), or nothing).
      self.assertEqual(expected_sum, self.evaluate(
          distribution.extended.read_var(ret_v_sum)))
      self.assertEqual(expected_mean, self.evaluate(
          distribution.extended.read_var(ret_v_mean)))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum._get()))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean._get()))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean))

  # TODO(priyag): Update this test to work in eager mode as well.
  def testDynamicRnnVariables(self, distribution):
    """Checks values returned by a bidirectional dynamic RNN in graph mode."""

    def model_fn():
      inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
      cell_fw = rnn_cell_impl.LSTMCell(300)
      cell_bw = rnn_cell_impl.LSTMCell(300)
      (outputs, _) = rnn.bidirectional_dynamic_rnn(
          cell_fw, cell_bw, inputs, dtype=dtypes.float32)
      return outputs

    with context.graph_mode(), distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      # Two variables are created by the RNN layer.
      self.assertEqual(2, len(result))
      for v in result:
        self.assertIsInstance(v, values.DistributedValues)
        _, v1 = distribution.experimental_local_results(v)
        self.assertStartsWith(v1._op.name, "replica_1/")

  def testSyncOnReadVariableUpdate(self, distribution):
    """Checks an ON_READ SUM variable before and after `extended.update`."""

    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
      return v_sum

    def update(var, value):
      return var.assign(value)

    with distribution.scope():
      ret_v_sum = distribution.extended.call_for_each_replica(model_fn)

      # Initialize variables.
      self.evaluate(variables.global_variables_initializer())
      # Assert that the aggregated value of the sync on read var is the sum
      # of the individual values before running the update ops.
      self.assertEqual(
          1.0,
          self.evaluate(
              distribution.experimental_local_results(ret_v_sum)
              [0].read_value()))
      self.assertEqual(2.0, self.evaluate(ret_v_sum))

      # Apply updates.
      update_ops = distribution.extended.update(
          ret_v_sum, update, args=(5.0,), group=False)
      self.evaluate(update_ops)
      # Assert that the aggregated value of the sync on read vars is the sum
      # of the individual values after running the update ops.
      self.assertEqual(
          5.0,
          self.evaluate(
              distribution.experimental_local_results(ret_v_sum)
              [0].read_value()))
      self.assertEqual(10.0, self.evaluate(ret_v_sum))

  def testVarDistributeStrategy(self, distribution):
    """Checks `distribute_strategy` on mirrored and ON_READ variables."""
    with distribution.scope():
      mirrored = variable_scope.variable(1.0)
      sync_on_read = variable_scope.variable(
          1.0, synchronization=variable_scope.VariableSynchronization.ON_READ)
      self.assertIs(distribution, mirrored.distribute_strategy)
      self.assertIs(distribution, sync_on_read.distribute_strategy)

  def testInitializer(self, distribution, mode):
    """Checks that variables loaded into a graph have an initializer."""
    if mode == "graph":
      self.skipTest("Skip graph mode")

    temp_dir = self.get_temp_dir()

    class Model(tracking_util.Checkpoint):

      def __init__(self):
        self._v = variables.Variable(1.0)

    with distribution.scope():
      m = Model()
    save.save(m, temp_dir)

    g = ops.Graph()
    with g.as_default():
      with distribution.scope():
        load.load(temp_dir)

      for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
        self.assertIsNotNone(v.initializer)

  def testCustomGradient(self, distribution):
    """Checks `recompute_grad` gradients yield one result per replica."""

    class CustomModel:

      def __init__(self):
        self._v = variables.Variable(1.0)

      def __call__(self):

        @custom_gradient.recompute_grad
        def _call():
          return self._v + 1

        return _call()

    with distribution.scope():
      model = CustomModel()

      @def_function.function
      def train_step():

        def replica_step():
          with backprop.GradientTape() as tape:
            result = model()
          return tape.gradient(result, [model._v])

        return distribution.run(replica_step)

    grads = distribution.experimental_local_results(train_step())
    self.assertLen(grads, distribution.num_replicas_in_sync)


if __name__ == "__main__":
  test.main()