# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CollectiveAllReduceStrategy."""

import copy
import functools

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training.server_lib import ClusterSpec


CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)
_CollectiveAllReduceStrategyExperimental = (
    collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental)


# TODO(b/231630416): Create more tests to cover the case where the strategy
# uses a different number of GPUs than the number of physical devices.
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        num_tpus=None):
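  """Creates a CollectiveAllReduceStrategy and its session target.

  Accelerator counts default to the locally available devices. With a
  cluster_spec, task_type and task_id, the returned target points at that
  task's grpc server; otherwise it is the empty (in-process) target.
  """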
  if num_gpus is None:
    num_gpus = context.num_gpus()
  if num_tpus is None:
    num_tpus = context.context().list_physical_devices('TPU')
  if num_tpus:
    tpu_strategy_util.initialize_tpu_system()

  if cluster_spec and task_type and task_id is not None:
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type=task_type,
        task_id=task_id,
        num_accelerators={'GPU': num_gpus, 'TPU': num_tpus})
    target = 'grpc://' + cluster_spec[task_type][task_id]
  else:
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec({}), num_accelerators={'GPU': num_gpus, 'TPU': num_tpus})
    target = ''

  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
      cluster_resolver=cluster_resolver)

  return strategy, target


class CollectiveAllReduceStrategyTestBase(
    multi_worker_test_base.MultiWorkerTestBase):
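  """Base class with helpers shared by the collective strategy tests."""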

  def setUp(self):
    # We use a different key_base for each test so that collective keys won't be
    # reused.
    CollectiveAllReduceStrategy._collective_key_base += 100000
    super(CollectiveAllReduceStrategyTestBase, self).setUp()

  def _get_test_object(self,
                       task_type,
                       task_id,
                       num_gpus=0,
                       num_tpus=0,
                       use_devices_arg=False):
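    """Returns a strategy and session target for the given task.

    When `use_devices_arg` is set and GPUs are requested, the strategy's
    extended object is rebuilt from an explicit device list instead of a
    cluster resolver.
    """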
    strategy, target = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type=task_type,
        task_id=task_id,
        num_gpus=num_gpus,
        num_tpus=num_tpus)

    if use_devices_arg and num_gpus > 0:
      devices = ['GPU:%d' % i for i in range(num_gpus)]
      # Temporary workaround to manually set the `_extended` field before device
      # initialization is exposed as a public interface.
      strategy._extended = CollectiveAllReduceExtended(
          container_strategy=strategy,
          cluster_resolver=None,
          communication_options=collective_util.Options(),
          devices=devices)
      # Manually set the field since the workaround bypasses the base
      # constructor, resulting in the absence of this field.
      strategy._extended._retrace_functions_for_each_device = (num_gpus > 1)

    return strategy, target

  def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
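    """Runs a few optimization steps and checks that the loss error drops."""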
    distribution, master_target = self._get_test_object(task_type, task_id,
                                                        num_gpus)
    with ops.Graph().as_default(), \
         self.cached_session(target=master_target) as sess, \
         distribution.scope():
      initializer = functools.partial(
          init_ops_v2.GlorotUniform(), (1, 1), dtype=dtypes.float32)
      kernel = variables.Variable(
          initial_value=initializer,
          name='gpu_%d/kernel' % distribution.extended._num_devices_per_worker,
          trainable=True)

      def loss_fn(x):
        y = array_ops.reshape(
            gen_math_ops.mat_mul(x, kernel), []) - constant_op.constant(1.)
        return y * y

      # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
      # multiple graphs (b/111216820).
      def grad_fn(x):
        loss = loss_fn(x)
        var_list = (
            variables.trainable_variables() + ops.get_collection(
                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
        grads = gradients.gradients(loss, var_list)
        ret = list(zip(grads, var_list))
        return ret

      def update(v, g):
        return v.assign_sub(0.05 * g, use_locking=True)

      one = constant_op.constant([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get the gradients and the variable list.
        g_v = distribution.extended.call_for_each_replica(grad_fn, args=[one])
        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = distribution.extended.read_var(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            # TODO(yuefengz): support non-Mirrored variables as destinations.
            g = distribution.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                distribution.extended.update(v, update, args=(g,),
                                             group=False)):
              after_list.append(distribution.extended.read_var(v))
        return before_list, after_list

      before_out, after_out = step()

      if (distribution.extended._local_device_type == 'GPU' and
          context.num_gpus() < distribution.extended._num_devices_per_worker):
        return True

      sess.run(variables.global_variables_initializer())

      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
        after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # The error should decrease as training progresses.
      self.assertLess(error_after, error_before)

  def _test_variable_initialization(self, task_type, task_id, num_gpus):
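    """Checks that replicas observe identically initialized variables."""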
    distribution, master_target = self._get_test_object(task_type, task_id,
                                                        num_gpus)
    with ops.Graph().as_default(), \
         self.cached_session(target=master_target) as sess, \
         distribution.scope():

      def model_fn():
        x = variable_scope.get_variable(
            'x',
            shape=(2, 3),
            initializer=init_ops.random_uniform_initializer(
                1.0, 10.0, dtype=dtypes.float32))
        return array_ops.identity(x)

      x = distribution.extended.call_for_each_replica(model_fn)
      reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
      x = distribution.experimental_local_results(x)[0]

      sess.run(variables.global_variables_initializer())

      x_value, reduced_x_value = sess.run([x, reduced_x])
      self.assertTrue(
          np.allclose(x_value, reduced_x_value, atol=1e-5),
          msg=('x_value = %r, reduced_x_value = %r' % (x_value,
                                                       reduced_x_value)))

  def _test_input_fn_iterator(self,
                              task_type,
                              task_id,
                              num_gpus,
                              input_fn,
                              expected_values,
                              test_reinitialize=True,
                              ignore_order=False,
                              use_devices_arg=False):
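    """Checks the values produced by an input_fn iterator on every replica."""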
    distribution, master_target = self._get_test_object(
        task_type, task_id, num_gpus, use_devices_arg=use_devices_arg)
    devices = distribution.extended.worker_devices

    with ops.Graph().as_default(), \
         self.cached_session(target=master_target) as sess:
      iterator = distribution.make_input_fn_iterator(input_fn)
      sess.run(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([distribute_utils.select_replica(
            r, next_element) for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(list(expected_value), list(computed_value))
        else:
          self.assertEqual(list(expected_value), list(computed_value))

      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        sess.run([distribute_utils.select_replica(r, next_element)
                  for r in range(len(devices))])

      # After re-initializing the iterator, we should be able to iterate again.
      if test_reinitialize:
        sess.run(iterator.initializer)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = sess.run([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))])
          if ignore_order:
            self.assertCountEqual(list(expected_value), list(computed_value))
          else:
            self.assertEqual(list(expected_value), list(computed_value))


class DistributedCollectiveAllReduceStrategyTest(
    CollectiveAllReduceStrategyTestBase,
    strategy_test_lib.DistributionTestBase,
    parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 3 workers."""
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=0)

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_num_replicas_in_sync(self):
    distribution, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=0,
        num_gpus=2)
    num_workers = len(self._cluster_spec.get('chief', []) +
                      self._cluster_spec.get('worker', []))
    self.assertEqual(2 * num_workers,
                     distribution.num_replicas_in_sync)

  @combinations.generate(combinations.combine(
      mode=['graph'],
      prefetch_to_device=[None, True]))
  def test_prefetch_to_device_dataset(self, prefetch_to_device):
    distribution, _ = self._get_test_object(
        task_type='worker', task_id=0, num_gpus=2)
    if prefetch_to_device is None:
      input_options = None
    else:
      input_options = distribute_lib.InputOptions(
          experimental_fetch_to_device=prefetch_to_device)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['GPU'])

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_prefetch_to_host_dataset(self):
    distribution, _ = self._get_test_object(
        task_type='worker', task_id=0, num_gpus=2)
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['CPU'])

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraph(self, required_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testVariableInitialization(self, required_gpus):
    self._run_between_graph_clients(
        self._test_variable_initialization,
        self._cluster_spec,
        num_gpus=required_gpus)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=[0, 1, 2], use_dataset=[True, False]))
  def testMakeInputFnIterator(self, required_gpus, use_dataset):
    def _worker_fn(task_type, task_id, required_gpus):
      if use_dataset:
        fn = lambda: dataset_ops.Dataset.range(20)
      else:
        def fn():
          dataset = dataset_ops.Dataset.range(20)
          it = dataset_ops.make_one_shot_iterator(dataset)
          return it.get_next
      # We use CPU as the device when required_gpus = 0.
      devices_per_worker = max(1, required_gpus)
      expected_values = [[i + j for j in range(devices_per_worker)]
                         for i in range(0, 20, devices_per_worker)]

      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=3 * devices_per_worker,
          expected_num_input_pipelines=3,
          expected_input_pipeline_id=task_id)
      self._test_input_fn_iterator(
          task_type,
          task_id,
          required_gpus,
          input_fn,
          expected_values,
          test_reinitialize=use_dataset,
          ignore_order=not use_dataset)

    self._run_between_graph_clients(_worker_fn, self._cluster_spec,
                                    required_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testUpdateConfigProto(self):
    strategy, _ = self._get_test_object(
        task_type='worker', task_id=1, num_gpus=2)

    config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])
    rewrite_options = config_proto.graph_options.rewrite_options
    rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed')

    new_config = strategy.update_config_proto(config_proto)

    # Verify group leader.
    self.assertEqual('/job:worker/replica:0/task:0',
                     new_config.experimental.collective_group_leader)

    # Verify device filters.
    self.assertEqual(['/job:worker/task:1'], new_config.device_filters)

    # Verify rewrite options.
    new_rewrite_options = new_config.graph_options.rewrite_options
    self.assertEqual(rewriter_config_pb2.RewriterConfig.ON,
                     new_rewrite_options.scoped_allocator_optimization)
    self.assertEqual(['CollectiveReduce'],
                     new_rewrite_options.scoped_allocator_opts.enable_op)


class DistributedCollectiveAllReduceStrategyTestWithChief(
    CollectiveAllReduceStrategyTestBase, parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 3 workers and 1 chief."""
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=0, has_chief=True)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraph(self, required_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testVariableInitialization(self, required_gpus):
    self._run_between_graph_clients(
        self._test_variable_initialization,
        self._cluster_spec,
        num_gpus=required_gpus)


class SingleWorkerCollectiveAllReduceStrategy(
    CollectiveAllReduceStrategyTestBase, strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['eager']))
  def testStrategyInitializationError(self):
    with self.assertRaisesRegex(
        ValueError,
        'cluster_resolver and devices cannot be set at the same time'):
      _ = collective_all_reduce_strategy.CollectiveAllReduceExtended(
          container_strategy=None,
          cluster_resolver=multi_worker_test_base.create_in_process_cluster(
              num_workers=3, num_ps=0),
          communication_options=collective_util.Options(),
          devices=['GPU:0', 'GPU:1'])

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testMinimizeLoss(self, required_gpus, use_devices_arg):
    # Collective ops don't support a strategy with only one device.
    if context.executing_eagerly():
      strategy, _ = self._get_test_object(
          None, None, required_gpus, use_devices_arg=use_devices_arg)
      self._test_minimize_loss_eager(strategy)
    else:
      self._test_minimize_loss_graph(None, None, required_gpus)

  @combinations.generate(
      combinations.combine(
          mode=['eager'], required_gpus=[1, 2], use_devices_arg=[True, False]))
  def testNumReplicasInSync(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, required_gpus, use_devices_arg=use_devices_arg)
    self.assertEqual(required_gpus, strategy.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=['eager'],
          required_tpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testMinimizeLossTPU(self, required_tpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, num_tpus=required_tpus, use_devices_arg=use_devices_arg)
    self._test_minimize_loss_eager(strategy)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testCallAndMergeExceptions(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    self._test_call_and_merge_exceptions(strategy)

  @combinations.generate(
      combinations.combine(
          mode=['graph'],
          required_gpus=2,
          use_dataset=[True, False],
          use_devices_arg=[True, False]))
  def testMakeInputFnIterator(self, required_gpus, use_dataset,
                              use_devices_arg):
    if use_dataset:
      fn = lambda: dataset_ops.Dataset.range(5 * required_gpus)
    else:
      def fn():
        dataset = dataset_ops.Dataset.range(5 * required_gpus)
        it = dataset_ops.make_one_shot_iterator(dataset)
        return it.get_next

    expected_values = [
        range(i, i + required_gpus) for i in range(0, 10, required_gpus)
    ]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterator(
        None,
        None,
        required_gpus,
        input_fn,
        expected_values,
        test_reinitialize=use_dataset,
        ignore_order=not use_dataset)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testReduceToCpu(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, required_gpus, use_devices_arg=use_devices_arg)
    with strategy.scope():
      result = strategy.extended.call_for_each_replica(_replica_id_f32)
      reduced = strategy.reduce(reduce_util.ReduceOp.SUM, result, axis=None)
      expected = sum(range(strategy.num_replicas_in_sync))
      self.assertEqual(expected, self.evaluate(reduced))

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceSum(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_sum(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceSumGradients(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_sum_gradients(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceSumGradientTape(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_sum_gradient_tape(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceMean(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_mean(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceMeanGradients(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_mean_gradients(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceMeanGradientTape(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_mean_gradient_tape(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testNumpyDataset(self, required_gpus, use_devices_arg):
    strategy, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    self._test_numpy_dataset(
        strategy, session=self.cached_session(target=target))

  @combinations.generate(
      combinations.combine(
          mode=['eager'], required_gpus=2, use_devices_arg=[True, False]))
  def testReplicateDataset(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i + 1] for i in range(0, 10, 2)]
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterable(strategy, input_fn, expected_values)

  @combinations.generate(
      combinations.combine(mode=['graph'], use_devices_arg=[True, False]))
  def testDeepCopy(self, use_devices_arg):
    distribution, _ = self._get_test_object(
        None, None, use_devices_arg=use_devices_arg)
    copy.deepcopy(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testSummaryForReplicaZeroOnly(self, required_gpus, use_devices_arg):
    strategy, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_summary_for_replica_zero_only(strategy)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testTrainableVariables(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    self._test_trainable_variable(strategy)


class LogicalDeviceTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['eager'], required_gpus=1))
  def testKeepLogicalDevice(self):
    gpus = tf_config.list_physical_devices('GPU')
    if len(gpus) > 1:
      self.skipTest('Skip the logical device test on multiple GPUs, since '
                    'partial GPU virtualization is not permitted.')
    # Logical devices cannot be changed after the context is initialized.
    context._reset_context()  # pylint: disable=protected-access
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=False, num_workers=1)
    resolver = cluster_resolver_lib.SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type='worker',
        task_id=0)

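    # Split every physical GPU into two small logical devices so that the
    # strategy below runs on top of virtual GPUs.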
    logical_gpus = len(gpus) * 2
    for i, device in enumerate(gpus):
      n = (i + 1) * logical_gpus // len(gpus) - i * logical_gpus // len(gpus)
      assert n > 0  # guaranteed if count >= len(devices)
      configs = []
      for ordinal in range(n):
        config = context.LogicalDeviceConfiguration(
            memory_limit=64,
            experimental_device_ordinal=ordinal)
        configs.append(config)

      tf_config.set_logical_device_configuration(device, configs)

    collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        cluster_resolver=resolver)
    # Since each physical GPU is split into two logical GPUs, there should be
    # twice as many logical GPUs as physical GPUs.
    self.assertLen(tf_config.list_logical_devices('GPU'), logical_gpus)
    context._reset_context()  # pylint: disable=protected-access


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        mode=['eager']))
class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase):
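  """Tests run with the standard multi-worker mirrored strategy combinations."""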

  def setUp(self):
    super().setUp()
    if context.context().list_physical_devices('TPU'):
      self.skipTest('Test not supported on TPUs')

  def test_replica_id_in_sync_group(self, strategy):

    def replica_fn():
      replica_ctx = distribution_strategy_context.get_replica_context()
      return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id

    results = test_util.gather(strategy, strategy.run(replica_fn))
    self.assertAllEqual(list(range(strategy.extended._num_replicas_in_sync)),
                        results[0].numpy())
    self.assertAllEqual(
        list(range(len(strategy.extended.worker_devices))) *
        strategy.extended._num_workers, results[1].numpy())

  def test_deep_copy_not_allowed(self, strategy):
    # The check-health thread is disabled in tests by default. We need to
    # enable it for this test to simulate the real world.
    strategy.extended._start_check_health_thread()
    try:
      with self.assertRaisesRegex(ValueError, 'cannot be deep copied'):
        copy.deepcopy(strategy)
      with self.assertRaisesRegex(ValueError, 'cannot be deep copied'):
        with ops.Graph().as_default():
          copy.deepcopy(strategy)
    finally:
      strategy.extended._stop_check_health_thread()


class ExperimentalCompatibilityTest(test.TestCase):

  def testIsInstance(self):
    # It's not uncommon for people to special-case MultiWorkerMirroredStrategy,
    # so we need to make sure the isinstance check works for combinations of
    # the experimental and non-experimental endpoints.
    strategy = CollectiveAllReduceStrategy()
    experimental_strategy = _CollectiveAllReduceStrategyExperimental()
    self.assertIsInstance(strategy, CollectiveAllReduceStrategy)
    self.assertIsInstance(strategy, _CollectiveAllReduceStrategyExperimental)
    self.assertIsInstance(experimental_strategy, CollectiveAllReduceStrategy)
    self.assertIsInstance(experimental_strategy,
                          _CollectiveAllReduceStrategyExperimental)

  def testName(self):
    # Estimator checks the __name__ to special case MultiWorkerMirroredStrategy.
    self.assertEqual(CollectiveAllReduceStrategy.__name__,
                     'CollectiveAllReduceStrategy')
    self.assertEqual(_CollectiveAllReduceStrategyExperimental.__name__,
                     'CollectiveAllReduceStrategy')


def _replica_id_f32():
  return math_ops.cast(
      distribution_strategy_context.get_replica_context()
      .replica_id_in_sync_group, dtypes.float32)


if __name__ == '__main__':
  # TODO(b/172304955): enable logical devices.
  test_util.main(config_logical_devices=False)