xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/mirrored_variable_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test MirroredVariable in MirroredStrategy and MultiWorkerMirroredStrategy."""
16
17from tensorflow.python.checkpoint import checkpoint as tracking_util
18from tensorflow.python.distribute import collective_all_reduce_strategy
19from tensorflow.python.distribute import combinations
20from tensorflow.python.distribute import distribute_utils
21from tensorflow.python.distribute import distribution_strategy_context as ds_context
22from tensorflow.python.distribute import strategy_combinations
23from tensorflow.python.distribute import values
24from tensorflow.python.eager import backprop
25from tensorflow.python.eager import context
26from tensorflow.python.eager import def_function
27from tensorflow.python.eager import test
28from tensorflow.python.framework import config
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import func_graph
32from tensorflow.python.framework import ops
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import custom_gradient
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import rnn
37from tensorflow.python.ops import rnn_cell_impl
38from tensorflow.python.ops import state_ops
39from tensorflow.python.ops import variable_scope
40from tensorflow.python.ops import variables
41from tensorflow.python.saved_model import load
42from tensorflow.python.saved_model import save
43
44
def _replica_id():
  """Returns the current replica's id within its sync group, as a tensor.

  The replica context's `replica_id_in_sync_group` is returned unchanged when
  it is already an `ops.Tensor`; otherwise it is wrapped with
  `constant_op.constant`.
  """
  rid = ds_context.get_replica_context().replica_id_in_sync_group
  return rid if isinstance(rid, ops.Tensor) else constant_op.constant(rid)
50
51
def _mimic_two_cpus():
  """Splits the first physical CPU into two logical CPU devices.

  This gives tests a "/device:CPU:0" and a "/device:CPU:1" to place replicas
  on, even on a host that exposes a single physical CPU.
  """
  first_cpu = config.list_physical_devices("CPU")[0]
  two_logical_cpus = [context.LogicalDeviceConfiguration() for _ in range(2)]
  config.set_logical_device_configuration(first_cpu, two_logical_cpus)
59
60
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            combinations.NamedDistribution(
                "Collective2CPUs",
                # pylint: disable=g-long-lambda
                lambda: collective_all_reduce_strategy.
                CollectiveAllReduceStrategy._from_local_devices((
                    "/device:CPU:0", "/device:CPU:1")),
                required_gpus=0)
        ],
        mode=["graph", "eager"]))
class MirroredVariableCreationTest(test.TestCase):
  """Tests the variable creator installed by multi-replica strategies.

  Every test runs once per (distribution, mode) combination generated by the
  decorator: a mirrored strategy over GPU and CPU, and a collective
  all-reduce strategy over two logical CPUs, each in graph and eager mode.

  Currently it assumes all strategy objects have two replicas.
  """

  @classmethod
  def setUpClass(cls):
    # Split the physical CPU into the two logical CPUs that the
    # "Collective2CPUs" combination places its replicas on. This runs before
    # the base-class set-up, presumably because logical devices can only be
    # configured before the runtime initializes them -- confirm.
    _mimic_two_cpus()
    super().setUpClass()

  def assertAllDifferent(self, objs):
    """Asserts that no two entries of the sequence `objs` are the same object."""
    # Identity is symmetric, so each unordered pair only needs one check.
    for i, first in enumerate(objs):
      for second in objs[i + 1:]:
        self.assertIsNot(first, second)

  # TODO(priyag): Modify more tests to use this helper and check more
  # properties.
  def _test_mv_properties(self, var, name, strategy):
    """Checks that `var` is a mirrored variable owned by `strategy`.

    Args:
      var: The distributed variable under test.
      name: The expected `var.name`, e.g. "foo:0".
      strategy: The strategy expected to own `var` and each of its components.
    """
    self.assertTrue(distribute_utils.is_mirrored(var))
    self.assertEqual(name, var.name)
    self.assertIs(strategy, var.distribute_strategy)
    # Fetch the per-replica components once instead of on every iteration.
    local_results = strategy.experimental_local_results(var)
    devices = var._devices  # pylint: disable=protected-access
    # Expect exactly one component per device. Without this check, a shorter
    # result list fails with an IndexError and a longer one is partly unchecked.
    self.assertLen(local_results, len(devices))
    for device, component in zip(devices, local_results):
      self.assertEqual(device, component.device)
      self.assertIs(
          strategy,
          component._distribute_strategy)  # pylint: disable=protected-access

  def testVariableInFuncGraph(self, distribution):
    """Variables created inside a FuncGraph under the strategy are mirrored."""

    def model_fn():
      v = variable_scope.variable(2.0, name="bar")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
      v1 = variable_scope.variable(1.0, name="foo")
      v2 = distribution.extended.call_for_each_replica(model_fn)

    self._test_mv_properties(v1, "foo:0", distribution)
    self._test_mv_properties(v2, "bar:0", distribution)

  def testVariableWithTensorInitialValueInFunction(self, distribution):
    """A variable initialized from a tensor inside tf.function reads 0s."""
    if not context.executing_eagerly():
      self.skipTest("`tf.function` is an eager-only feature")

    v = [None]

    def model_fn():
      # Only the first call creates the variable; later calls reuse it.
      if v[0] is None:
        init_val = array_ops.zeros([])
        v[0] = variables.Variable(init_val)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v[0]

    @def_function.function(autograph=False)
    def make_v1():
      return distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn))

    self.assertAllEqual([0, 0], make_v1())

  def testSingleVariable(self, distribution):
    """A variable created in the replica function becomes one mirrored var."""

    def model_fn():
      # This variable should be created only once across the threads because of
      # special variable_creator functions used by
      # `distribution.extended.call_for_each_replica`.
      v = variable_scope.variable(1.0, name="foo")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self._test_mv_properties(result, "foo:0", distribution)

  def testUnnamedVariable(self, distribution):
    """An unnamed mirrored variable gets the default name "Variable:0"."""

    def model_fn():
      v = variable_scope.variable(1.0)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self._test_mv_properties(result, "Variable:0", distribution)

  def testMultipleVariables(self, distribution):
    """Several variables created in one replica function are all mirrored."""

    def model_fn():
      vs = []
      for i in range(5):
        vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return vs

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      for i, v in enumerate(result):
        self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)

  def testMultipleVariablesWithSameCanonicalName(self, distribution):
    """Names that look auto-uniquified (e.g. "foo_1/bar") are kept verbatim."""

    def model_fn():
      vs = []
      vs.append(variable_scope.variable(1.0, name="foo/bar"))
      vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
      vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
      vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return vs

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      for v in result:
        self.assertTrue(distribute_utils.is_mirrored(v))
      self.assertEqual(4, len(result))
      self.assertEqual("foo/bar:0", result[0].name)
      self.assertEqual("foo_1/bar:0", result[1].name)
      self.assertEqual("foo_1/bar_1:0", result[2].name)
      self.assertEqual("foo/bar_1:0", result[3].name)

  def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
    """Per-replica name differences resolve to the first replica's name."""

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_mirrored(result))
      # The resulting mirrored variable will use the name from the first device.
      self.assertEqual("foo_0:0", result.name)

  def testWithVariableAndVariableScope(self, distribution):
    """`variable()` honours variable scopes, synchronization and aggregation."""

    def model_fn():
      v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.variable(1.0, name="var1")
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        v2 = variable_scope.variable(
            1.0,
            name="var2",
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.variable(
            1.0,
            name="var3",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3

    with distribution.scope():
      v = variable_scope.variable(1.0, name="var-main0")
      self.assertEqual("var-main0:0", v.name)

      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(4, len(result))
      v0, v1, v2, v3 = result
      self.assertTrue(distribute_utils.is_mirrored(v0))
      self.assertEqual("var0:0", v0.name)
      self.assertTrue(distribute_utils.is_mirrored(v1))
      self.assertEqual("common/var1:0", v1.name)
      self.assertTrue(distribute_utils.is_sync_on_read(v2))
      self.assertEqual("common/var2:0", v2.name)
      self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
      self.assertTrue(distribute_utils.is_mirrored(v3))
      self.assertEqual("common/var3:0", v3.name)
      self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)

  def testWithGetVariableAndVariableScope(self, distribution):
    """`get_variable()` honours scopes, synchronization and aggregation."""

    def model_fn():
      v0 = variable_scope.get_variable("var0", [1])
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.get_variable("var1", [1])
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        v2 = variable_scope.get_variable(
            "var2", [1],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.get_variable(
            "var3", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3

    with distribution.scope():
      with variable_scope.variable_scope("main"):
        v = variable_scope.get_variable("var-main0", [1])
        self.assertEqual("main/var-main0:0", v.name)

        result = distribution.extended.call_for_each_replica(model_fn)
        self.assertEqual(4, len(result))
        v0, v1, v2, v3 = result
        self.assertTrue(distribute_utils.is_mirrored(v0))
        self.assertEqual("main/var0:0", v0.name)
        self.assertTrue(distribute_utils.is_mirrored(v1))
        self.assertEqual("main/common/var1:0", v1.name)
        self.assertTrue(distribute_utils.is_sync_on_read(v2))
        self.assertEqual("main/common/var2:0", v2.name)
        self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
        self.assertTrue(distribute_utils.is_mirrored(v3))
        self.assertEqual("main/common/var3:0", v3.name)
        self.assertEqual(variable_scope.VariableAggregation.MEAN,
                         v3.aggregation)

  def testOnlyFirstReplicaUpdatesVariables(self, distribution):
    """Checks ONLY_FIRST_REPLICA aggregation for ON_READ and ON_WRITE vars."""

    def create_fn():
      aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
      v0 = variable_scope.variable(
          2.0,
          name="on_read",
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=aggregation)
      v1 = variable_scope.variable(
          3.0,
          name="on_write",
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=aggregation)
      return v0, v1

    with distribution.scope():
      v0, v1 = distribution.extended.call_for_each_replica(create_fn)
      self.evaluate(v0.initializer)
      self.assertEqual(
          2.0, self.evaluate(distribution.experimental_local_results(v0)[0]))
      self.assertEqual(
          2.0, self.evaluate(distribution.experimental_local_results(v0)[1]))
      self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
      self.evaluate(v1.initializer)
      self.assertEqual(
          3.0, self.evaluate(distribution.experimental_local_results(v1)[0]))
      self.assertEqual(
          3.0, self.evaluate(distribution.experimental_local_results(v1)[1]))
      self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))

      def replica_id_plus_one():
        return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)

      # Update using the assign_add member function.
      def update_member_fn():
        update0 = v0.assign_add(5.0 * replica_id_plus_one())
        update1 = v1.assign_add(7.0 * replica_id_plus_one())
        return update0, update1

      update0a, update1a = distribution.extended.call_for_each_replica(
          update_member_fn)

      # Update "sync on read" variable.
      self.evaluate(distribution.group(update0a))
      local_results = self.evaluate(distribution.experimental_local_results(v0))
      self.assertEqual(2.0 + 5.0, local_results[0])
      # Writes are not synchronized for "sync on read" variables,
      # so device[1] can end up with a different value.
      self.assertEqual(2.0 + 2 * 5.0, local_results[1])
      # Always reads from device 0.
      self.assertEqual(2.0 + 5.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1a))
      local_results1 = self.evaluate(
          distribution.experimental_local_results(v1))
      self.assertEqual(3.0 + 7.0, local_results1[0])
      # Writes are synchronized for v1, only the argument to assign_add on
      # device[0] is used.
      self.assertEqual(3.0 + 7.0, local_results1[1])
      self.assertEqual(3.0 + 7.0,
                       self.evaluate(distribution.extended.read_var(v1)))

      # Update using state_ops.assign_add global function.
      def update_state_ops_fn():
        update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
        update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
        return update0, update1

      update0b, update1b = distribution.extended.call_for_each_replica(
          update_state_ops_fn)
      self.evaluate(distribution.group(update0b))

      # Update "sync on read" variable.
      local_results = self.evaluate(distribution.experimental_local_results(v0))
      self.assertEqual(2.0 + 5.0 + 11.0, local_results[0])
      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0, local_results[1])
      self.assertEqual(2.0 + 5.0 + 11.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1b))
      local_results1 = self.evaluate(
          distribution.experimental_local_results(v1))
      self.assertEqual(3.0 + 7.0 + 13.0, local_results1[0])
      self.assertEqual(3.0 + 7.0 + 13.0, local_results1[1])
      self.assertEqual(3.0 + 7.0 + 13.0,
                       self.evaluate(distribution.extended.read_var(v1)))

  def testNoneSynchronizationWithGetVariable(self, distribution):
    """`get_variable` rejects NONE synchronization under the strategy."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with "):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testNoneSynchronizationWithVariable(self, distribution):
    """`variable` rejects NONE synchronization under the strategy."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with "):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testInvalidSynchronizationWithVariable(self, distribution):
    """An unknown synchronization value raises ValueError."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable synchronization mode: Invalid for "
          "variable: v"):
        variable_scope.variable(1.0, name="v", synchronization="Invalid")

  def testInvalidAggregationWithGetVariable(self, distribution):
    """An unknown aggregation value passed to `get_variable` raises."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testInvalidAggregationWithVariable(self, distribution):
    """An unknown aggregation value passed to `variable` raises."""
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testNonMatchingVariableCreation(self, distribution):
    """Replicas creating differently named variables raise RuntimeError."""

    def model_fn(name):
      v = variable_scope.variable(1.0, name=name)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      names = values.PerReplica(("foo", "bar"))
      with self.assertRaises(RuntimeError):
        _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))

  def testSyncOnReadVariable(self, distribution):
    """SUM and MEAN sync-on-read variables share a wrapper and reduce on read."""

    all_v_sum = {}
    all_v_mean = {}
    components_sum = {}
    components_mean = {}

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      v_mean = variable_scope.variable(
          4.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.MEAN)
      self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
      self.assertTrue(distribute_utils.is_sync_on_read(v_mean))
      updates = [
          v_sum.assign_add(2.0 + replica_id),
          v_mean.assign(6.0 * replica_id)
      ]
      all_v_sum[replica_id] = v_sum
      all_v_mean[replica_id] = v_mean
      c_sum = v_sum._get()
      c_mean = v_mean._get()
      components_sum[replica_id] = c_sum
      components_mean[replica_id] = c_mean
      self.assertIsNot(v_sum, c_sum)
      self.assertIsNot(v_mean, c_mean)
      return updates, v_sum, v_mean, c_sum, c_mean

    with distribution.scope():
      # Create "sum" and "mean" versions of SyncOnReadVariables.
      ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
          distribution.extended.call_for_each_replica(model_fn))
      # Should see the same wrapping instance in all replicas.
      self.assertIs(all_v_sum[0], ret_v_sum)
      self.assertIs(all_v_mean[0], ret_v_mean)
      self.assertIs(all_v_sum[0], all_v_sum[1])
      self.assertIs(all_v_mean[0], all_v_mean[1])

      # Regroup should recover the same wrapper.
      self.assertIs(ret_v_sum, regrouped_sum)
      self.assertIs(ret_v_mean, regrouped_mean)
      self.assertIsNot(components_sum[0], components_sum[1])
      self.assertIsNot(components_mean[0], components_mean[1])

      # Apply updates
      self.evaluate(variables.global_variables_initializer())
      self.evaluate([
          y for x in ret_ops  # pylint: disable=g-complex-comprehension
          for y in distribution.experimental_local_results(x)
      ])
      expected_sum = 0.0
      expected_mean = 0.0
      for i, _ in enumerate(distribution.extended.worker_devices):
        # Should see different values on different devices.
        v_sum_value = self.evaluate(
            distribution.experimental_local_results(ret_v_sum)[i].read_value())
        v_mean_value = self.evaluate(
            distribution.experimental_local_results(ret_v_mean)[i].read_value())
        expected = i + 3.0
        self.assertEqual(expected, v_sum_value)
        expected_sum += expected
        expected = i * 6.0
        self.assertEqual(expected, v_mean_value)
        expected_mean += expected
      expected_mean /= len(distribution.extended.worker_devices)

      # Without reading a specific device's component, should return the value
      # you get by applying the reduction across all replicas (whether you use
      # read_var(), _get(), or nothing).
      self.assertEqual(expected_sum, self.evaluate(
          distribution.extended.read_var(ret_v_sum)))
      self.assertEqual(expected_mean, self.evaluate(
          distribution.extended.read_var(ret_v_mean)))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum._get()))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean._get()))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean))

  # TODO(priyag): Update this test to work in eager mode as well.
  def testDynamicRnnVariables(self, distribution):
    """Ops built by a bidirectional dynamic RNN get per-replica name scopes."""

    def model_fn():
      inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
      cell_fw = rnn_cell_impl.LSTMCell(300)
      cell_bw = rnn_cell_impl.LSTMCell(300)
      (outputs, _) = rnn.bidirectional_dynamic_rnn(
          cell_fw, cell_bw, inputs, dtype=dtypes.float32)
      return outputs

    with context.graph_mode(), distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      # `bidirectional_dynamic_rnn` returns its outputs as a (forward, backward)
      # pair, so `result` holds two distributed output tensors (not variables).
      self.assertEqual(2, len(result))
      for output in result:
        self.assertIsInstance(output, values.DistributedValues)
        # The op built for the second replica is expected to live under the
        # "replica_1/" name scope.
        _, replica_1_output = distribution.experimental_local_results(output)
        self.assertStartsWith(replica_1_output._op.name, "replica_1/")

  def testSyncOnReadVariableUpdate(self, distribution):
    """`extended.update` on a SUM sync-on-read variable updates every replica."""

    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
      return v_sum

    def update(var, value):
      return var.assign(value)

    with distribution.scope():
      ret_v_sum = distribution.extended.call_for_each_replica(model_fn)

      # Initialize variables.
      self.evaluate(variables.global_variables_initializer())
      # Assert that the aggregated value of the sync on read var is the sum
      # of the individual values before running the update ops.
      self.assertEqual(
          1.0,
          self.evaluate(
              distribution.experimental_local_results(ret_v_sum)
              [0].read_value()))
      self.assertEqual(2.0, self.evaluate(ret_v_sum))

      # Apply updates.
      update_ops = distribution.extended.update(
          ret_v_sum, update, args=(5.0,), group=False)
      self.evaluate(update_ops)
      # Assert that the aggregated value of the sync on read vars is the sum
      # of the individual values after running the update ops.
      self.assertEqual(
          5.0,
          self.evaluate(
              distribution.experimental_local_results(ret_v_sum)
              [0].read_value()))
      self.assertEqual(10.0, self.evaluate(ret_v_sum))

  def testVarDistributeStrategy(self, distribution):
    """Mirrored and sync-on-read variables both report the creating strategy."""
    with distribution.scope():
      mirrored = variable_scope.variable(1.0)
      sync_on_read = variable_scope.variable(
          1.0, synchronization=variable_scope.VariableSynchronization.ON_READ)
      self.assertIs(distribution, mirrored.distribute_strategy)
      self.assertIs(distribution, sync_on_read.distribute_strategy)

  def testInitializer(self, distribution, mode):
    """Variables loaded from a SavedModel into a graph have initializers."""
    if mode == "graph":
      self.skipTest("Skip graph mode")

    temp_dir = self.get_temp_dir()

    class Model(tracking_util.Checkpoint):

      def __init__(self):
        self._v = variables.Variable(1.0)

    with distribution.scope():
      m = Model()
    save.save(m, temp_dir)

    g = ops.Graph()
    with g.as_default():
      with distribution.scope():
        load.load(temp_dir)

      for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
        self.assertIsNotNone(v.initializer)

  def testCustomGradient(self, distribution):
    """`recompute_grad` over a strategy-created variable works in `run`."""

    class CustomModel:

      def __init__(self):
        self._v = variables.Variable(1.0)

      def __call__(self):

        @custom_gradient.recompute_grad
        def _call():
          return self._v + 1

        return _call()

    with distribution.scope():
      model = CustomModel()

      @def_function.function
      def train_step():

        def replica_step():
          with backprop.GradientTape() as tape:
            result = model()
          return tape.gradient(result, [model._v])

        return distribution.run(replica_step)

    grads = distribution.experimental_local_results(train_step())
    self.assertLen(grads, distribution.num_replicas_in_sync)
643
644
if __name__ == "__main__":
  # When executed as a script, hand control to the TensorFlow test runner
  # imported above as `test`.
  test.main()
647