# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ParameterServerStrategy."""

import copy
import threading

from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_util

CHIEF = run_config.TaskType.CHIEF
WORKER = run_config.TaskType.WORKER
PS = run_config.TaskType.PS


def _get_replica_id_integer():
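  """Returns the replica id, as a constant value if it is a Tensor."""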
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    replica_id = tensor_util.constant_value(replica_id)
  return replica_id


def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        sess_config=None):
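  """Creates a strategy, session target, and session config for the tests.

  Returns a `ParameterServerStrategyV1` for the multi-worker case (when a
  cluster spec, task type, and task id are all given) or a local
  `CentralStorageStrategy` otherwise, along with the corresponding session
  target and an updated session config.
  """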
  sess_config = sess_config or config_pb2.ConfigProto()
  if num_gpus is None:
    num_gpus = context.num_gpus()
  if cluster_spec and task_type and task_id is not None:
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type=task_type,
        task_id=task_id,
        num_accelerators={'GPU': num_gpus})
    distribution = parameter_server_strategy.ParameterServerStrategyV1(
        cluster_resolver)
    target = 'grpc://' + cluster_spec[WORKER][task_id]
  else:
    distribution = (
        central_storage_strategy.CentralStorageStrategy._from_num_gpus(num_gpus)
    )
    target = ''

  sess_config = copy.deepcopy(sess_config)
  sess_config = distribution.update_config_proto(sess_config)

  return distribution, target, sess_config


class ParameterServerStrategyTestBase(
    multi_worker_test_base.MultiWorkerTestBase):
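  """Base class with shared device-assignment and training test helpers."""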

  def setUp(self):
    self._result = 0
    self._lock = threading.Lock()
    self._init_condition = threading.Condition()
    self._init_reached = 0
    self._finish_condition = threading.Condition()
    self._finish_reached = 0
    self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True)
    super(ParameterServerStrategyTestBase, self).setUp()

  def _get_test_objects(self, task_type, task_id, num_gpus):
    return create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type=task_type,
        task_id=task_id,
        num_gpus=num_gpus,
        sess_config=self._sess_config)

  def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
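    """Checks op and variable placement under a multi-worker strategy."""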
    worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
    d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=sess_config) as sess, \
         d.scope():

      # Define a variable outside the call_for_each_replica scope.
      n = variable_scope.get_variable('n', initializer=10.0)
      self.assertEqual(n.device, '/job:ps/task:0')

      def model_fn():
        if num_gpus == 0:
          last_part_device = 'device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          last_part_device = ('device:GPU:%d' % replica_id)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, worker_device + '/' + last_part_device)
        self.assertEqual(b.device, worker_device + '/' + last_part_device)
        self.assertEqual(c.device, worker_device + '/' + last_part_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        # The variable x is on task 1 since the device_function has been
        # called once before the model_fn.
        self.assertEqual(x.device, '/job:ps/task:1')
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device,
                         '/job:worker/replica:0/task:0/%s' % last_part_device)

        # The colocate_vars_with can override the distribution's device.
        with d.extended.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
        # We add an identity here to avoid complaints about summing
        # non-distributed values.
        y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(y.device, '/job:ps/task:1')
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(z.device, '/job:ps/task:0')
        self.assertNotEqual(z.device, x.device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, worker_device + '/' + last_part_device)

        # The device scope is merged with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, worker_device + '/device:CPU:1')

        # This ops.colocate_with will be ignored when defining a variable but
        # not for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          v = variable_scope.get_variable('v', initializer=30.0)
          h = f + 1.0
        self.assertIn('/job:ps/', u.device)
        self.assertIn('/job:ps/', v.device)
        # u and v are on different parameter servers.
        self.assertTrue(u.device != x.device or v.device != x.device)
        self.assertTrue(u.device == x.device or v.device == x.device)
        # Here h is not on a worker. Note that h.device is canonical while
        # x.device is not, so we only check that h is placed on a parameter
        # server.
        self.assertIn('/job:ps/', h.device)
        return y_add, z_add, f

      y, z, f = d.extended.call_for_each_replica(model_fn)
      self.assertNotEqual(y, None)
      self.assertNotEqual(z, None)
      self.assertNotEqual(f, None)

      if context.num_gpus() >= 1 and num_gpus <= 1:
        self.evaluate(variables.global_variables_initializer())
        y_val, z_val, f_val = sess.run([y, z, f])
        self.assertEqual(y_val, 33.0)
        self.assertEqual(z_val, 43.0)
        self.assertEqual(f_val, 46.0)

  def _test_device_assignment_distributed_enable_partitioner(
      self, task_type, task_id, num_gpus):
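    """Checks placement of partitioned variables across parameter servers."""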
    d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
    num_shards = len(d.extended.parameter_devices)
    partitioner = partitioned_variables.fixed_size_partitioner(num_shards)
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=sess_config) as sess, \
         d.scope():

      n = variable_scope.get_variable(
          'n',
          initializer=constant_op.constant([10.0, 20.0]),
          aggregation=variable_scope.VariableAggregation.SUM,
          partitioner=partitioner)

      for part_id, var in enumerate(n):
        self.assertEqual(var.device, '/job:ps/task:%d' % part_id)

      def model_fn():
        a = constant_op.constant([3.0, 5.0])
        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x',
              initializer=constant_op.constant([10.0, 20.0]),
              aggregation=variable_scope.VariableAggregation.SUM,
              partitioner=partitioner)
          x_add = x.assign_add(a, name='x_add')
        # The variable x is on task 1 since the device_function has been
        # called once before the model_fn.
        for part_id, var in enumerate(x):
          self.assertEqual(var.device, '/job:ps/task:%d' % part_id)
          self.assertEqual(var.device, x_add[part_id].device)

        return x_add

      x = d.extended.call_for_each_replica(model_fn)

      if context.num_gpus() >= 1:
        self.evaluate(variables.global_variables_initializer())
        x_val = sess.run(x)
        if num_gpus < 1:
          self.assertEqual(x_val, [13.0, 25.0])
        else:
          x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus]
          self.assertEqual(x_val, x_expect)

  def _test_device_assignment_local(self,
                                    d,
                                    compute_device='CPU',
                                    variable_device='CPU',
                                    num_gpus=0):
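    """Checks op and variable placement for a local strategy."""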
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=self._sess_config) as sess, \
         d.scope():

      def model_fn():
        if 'CPU' in compute_device:
          replica_compute_device = '/device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          replica_compute_device = ('/device:GPU:%d' % replica_id)
        replica_compute_device = device_util.canonicalize(
            replica_compute_device)

        if 'CPU' in variable_device:
          replica_variable_device = '/device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          replica_variable_device = ('/device:GPU:%d' % replica_id)
        replica_variable_device = device_util.canonicalize(
            replica_variable_device)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, replica_compute_device)
        self.assertEqual(b.device, replica_compute_device)
        self.assertEqual(c.device, replica_compute_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/device:GPU:2'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        self.assertEqual(
            device_util.canonicalize(x.device), replica_variable_device)
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2'))

        # The colocate_vars_with can override the distribution's device.
        with d.extended.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
        # We add an identity here to avoid complaints about summing
        # non-distributed values.
        y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(
            device_util.canonicalize(y.device), replica_variable_device)
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(
            device_util.canonicalize(z.device), replica_variable_device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, replica_compute_device)

        # The device scope is merged with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))

        # This ops.colocate_with will be ignored when defining a variable but
        # not for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          h = f + 1.0
        self.assertEqual(
            device_util.canonicalize(u.device), replica_variable_device)
        self.assertEqual(
            device_util.canonicalize(x.device),
            device_util.canonicalize(h.device))
        return y_add, z_add, f

      y, z, f = d.extended.call_for_each_replica(model_fn)
      self.assertNotEqual(y, None)
      self.assertNotEqual(z, None)
      self.assertNotEqual(f, None)

      if context.num_gpus() >= 1 and num_gpus <= 1:
        self.evaluate(variables.global_variables_initializer())
        y_val, z_val, f_val = sess.run([y, z, f])
        self.assertEqual(y_val, 33.0)
        self.assertEqual(z_val, 43.0)
        self.assertEqual(f_val, 46.0)

  def _test_simple_increment(self, task_type, task_id, num_gpus):
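    """Increments variables from each worker and checks aggregated values."""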
    d, master_target, sess_config = self._get_test_objects(
        task_type, task_id, num_gpus)
    if d.extended._cluster_spec:
      num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER))
      if 'chief' in d.extended._cluster_spec.as_dict():
        num_workers += 1
    else:
      num_workers = 1
    with ops.Graph().as_default(), \
         self.cached_session(target=master_target,
                             config=sess_config) as sess, \
         d.scope():

      def model_fn():
        x = variable_scope.get_variable(
            'x', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        y = variable_scope.get_variable(
            'y', initializer=20.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        z = variable_scope.get_variable(
            'z', initializer=30.0,
            aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)

        # We explicitly make a constant tensor here to avoid complaints about
        # summing non-distributed values.
        one = constant_op.constant(1.0)
        x_add = x.assign_add(one, use_locking=True)
        y_add = y.assign_add(one, use_locking=True)
        z_add = z.assign_add(one, use_locking=True)

        train_op = control_flow_ops.group(x_add, y_add, z_add)
        return x, y, z, train_op

      x, y, z, train_op = d.extended.call_for_each_replica(model_fn)
      train_op = d.group(train_op)

      if task_id == 0:
        self.evaluate(variables.global_variables_initializer())

      # Workers wait for the chief worker to finish initializing variables.
      self._init_condition.acquire()
      self._init_reached += 1
      while self._init_reached != num_workers:
        self._init_condition.wait()
      self._init_condition.notify_all()
      self._init_condition.release()

      sess.run(train_op)

      # Wait for other workers to finish training.
      self._finish_condition.acquire()
      self._finish_reached += 1
      while self._finish_reached != num_workers:
        self._finish_condition.wait()
      self._finish_condition.notify_all()
      self._finish_condition.release()

      x_val, y_val, z_val = sess.run([x, y, z])
      self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas_in_sync)
      self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync)
      self.assertEqual(z_val, 30.0 + 1.0 * num_workers)

  def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
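    """Runs a few training steps and checks that the loss error decreases."""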
    d, master_target, sess_config = self._get_test_objects(
        task_type, task_id, num_gpus)
    if task_type:
      # Multi-worker
      assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec
      num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER))
      if CHIEF in d.extended._cluster_spec.as_dict():
        num_workers += 1
    else:
      # local
      num_workers = 1

    with ops.Graph().as_default(), \
         self.cached_session(target=master_target,
                             config=sess_config) as sess, \
         d.scope():
      kernel = strategy_test_lib.create_variable_like_keras_layer(
          'kernel', (1, 1), dtypes.float32,)

      def loss_fn(x):
        y = array_ops.reshape(
            math_ops.matmul(x, kernel), []) - constant_op.constant(1.)
        return y * y

      # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
      # multiple graphs (b/111216820).
      def grad_fn(x):
        loss = loss_fn(x)
        var_list = (
            variables.trainable_variables() + ops.get_collection(
                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
        grads = gradients.gradients(loss, var_list)
        ret = list(zip(grads, var_list))
        return ret

      def update(v, g):
        return v.assign_sub(0.05 * g, use_locking=True)

      one = constant_op.constant([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            # TODO(yuefengz): support non-Mirrored variable as destinations.
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      before_out, after_out = step()

      if (not task_type or
          multi_worker_util.is_chief(
              d.extended._cluster_spec, task_type, task_id)):
        self.evaluate(variables.global_variables_initializer())

      # Workers wait for the chief worker to finish initializing variables.
      self._init_condition.acquire()
      self._init_reached += 1
      while self._init_reached != num_workers:
        self._init_condition.wait()
      self._init_condition.notify_all()
      self._init_condition.release()

      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
        after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # The error should go down after training.
      self.assertLess(error_after, error_before)

  def _test_input_fn_iterator(self,
                              task_type,
                              task_id,
                              num_gpus,
                              input_fn,
                              expected_values,
                              test_reinitialize=True,
                              ignore_order=False):
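    """Checks that an input_fn iterator yields the expected values."""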
    distribution, master_target, config = self._get_test_objects(
        task_type, task_id, num_gpus)
    devices = distribution.extended.worker_devices

    with ops.Graph().as_default(), \
         self.cached_session(config=config,
                             target=master_target) as sess:
      iterator = distribution.make_input_fn_iterator(input_fn)
      sess.run(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([distribute_utils.select_replica(
            r, next_element) for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
          self.assertEqual(expected_value, computed_value)

      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        sess.run([distribute_utils.select_replica(r, next_element)
                  for r in range(len(devices))])

      # After re-initializing the iterator, we should be able to iterate again.
      if test_reinitialize:
        sess.run(iterator.initializer)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = sess.run([distribute_utils.select_replica(
              r, next_element) for r in range(len(devices))])
          if ignore_order:
            self.assertCountEqual(expected_value, computed_value)
          else:
            self.assertEqual(expected_value, computed_value)


class ParameterServerStrategyTest(
    ParameterServerStrategyTestBase,
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):
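  """Tests for ParameterServerStrategyV1 against a multi-worker cluster."""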

  @classmethod
  def setUpClass(cls):
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2)
    cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0]

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_num_replicas_in_sync(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    # All the devices on a given worker are in sync, so the number of replicas
    # in sync equals the number of GPUs on each worker.
    self.assertEqual(2, strategy.num_replicas_in_sync)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalCPU(self):
    strategy, _, _ = create_test_objects(num_gpus=0)
    self._test_device_assignment_local(
        strategy, compute_device='CPU', variable_device='CPU', num_gpus=0)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalOneGPU(self):
    strategy, _, _ = create_test_objects(num_gpus=1)
    self._test_device_assignment_local(
        strategy, compute_device='GPU', variable_device='GPU', num_gpus=1)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalTwoGPUs(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    self._test_device_assignment_local(
        strategy, compute_device='GPU', variable_device='CPU', num_gpus=2)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testDeviceAssignmentDistributed(self, num_gpus):
    self._test_device_assignment_distributed('worker', 1, num_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus):
    self._test_device_assignment_distributed_enable_partitioner(
        'worker', 1, num_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testSimpleBetweenGraph(self):
    self._run_between_graph_clients(self._test_simple_increment,
                                    self._cluster_spec, context.num_gpus())

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testLocalSimpleIncrement(self, required_gpus):
    self._test_simple_increment(None, 0, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraphDistributed(self, required_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraphLocal(self, required_gpus):
    self._test_minimize_loss_graph(None, None, required_gpus)

  # TODO(priyag): Refactor this and other multi worker tests.
  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=[1, 2], use_dataset=[True, False]))
  def testMakeInputFnIteratorDistributed(self, required_gpus, use_dataset):
    if use_dataset:
      fn = lambda: dataset_ops.Dataset.range(100)
    else:
      def fn():
        dataset = dataset_ops.Dataset.range(100)
        it = dataset_ops.make_one_shot_iterator(dataset)
        return it.get_next

    expected_values = [[i + j
                        for j in range(required_gpus)]
                       for i in range(0, 100, required_gpus)]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=3,
        expected_input_pipeline_id=1)  # because task_id = 1
    self._test_input_fn_iterator(
        'worker',
        1,
        required_gpus,
        input_fn,
        expected_values,
        test_reinitialize=use_dataset,
        ignore_order=not use_dataset)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=[1, 2], use_dataset=[True, False]))
  def testMakeInputFnIteratorLocal(self, required_gpus, use_dataset):
    if use_dataset:
      fn = lambda: dataset_ops.Dataset.range(100)
    else:

      def fn():
        dataset = dataset_ops.Dataset.range(100)
        it = dataset_ops.make_one_shot_iterator(dataset)
        return it.get_next

    expected_values = [[i + j
                        for j in range(required_gpus)]
                       for i in range(0, 100, required_gpus)]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)  # only one worker and pipeline for local.
    self._test_input_fn_iterator(
        None,
        None,
        required_gpus,
        input_fn,
        expected_values,
        test_reinitialize=use_dataset,
        ignore_order=not use_dataset)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepUpdate(self):
    strategy, _, _ = create_test_objects()
    self._test_global_step_update(strategy)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testUpdateConfigProtoMultiWorker(self):
    strategy, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=1,
        num_gpus=2)

    config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])

    new_config = strategy.update_config_proto(config_proto)

    # Verify device filters.
    self.assertEqual(['/job:worker/task:1', '/job:ps'],
                     new_config.device_filters)

    # Verify isolate_session_state.
    self.assertFalse(new_config.isolate_session_state)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testUpdateConfigProtoLocal(self):
    strategy, _, _ = create_test_objects(num_gpus=2)

    config_proto = config_pb2.ConfigProto()
    new_config = strategy.update_config_proto(config_proto)

    # Verify isolate_session_state.
    self.assertTrue(new_config.isolate_session_state)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInMultiWorkerMode(self):
    strategy, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=1,
        num_gpus=0)
    self.assertTrue(strategy.extended._in_multi_worker_mode())

  @combinations.generate(combinations.combine(mode=['eager']))
  def testEagerCustomTrainingUnimplementedError(self):
    cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2)
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type='worker',
        task_id=1,
        num_accelerators={'GPU': 0})
    strategy = parameter_server_strategy.ParameterServerStrategyV1(
        cluster_resolver)
    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])

    def train_step(data):
      return math_ops.square(data)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.experimental_distribute_dataset,
                           dataset.batch(2))

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.distribute_datasets_from_function,
                           lambda _: dataset)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.scope)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.run, train_step)

  @combinations.generate(combinations.combine(
      mode=['graph'],
      prefetch_to_device=[None, True]))
  def test_prefetch_to_device_dataset(self, prefetch_to_device):
    distribution, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=0,
        num_gpus=2)
    if prefetch_to_device is None:
      input_options = None
    else:
      input_options = distribute_lib.InputOptions(
          experimental_fetch_to_device=prefetch_to_device)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(  # pylint: disable=assignment-from-no-return
        dataset,
        options=input_options)
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['GPU'])

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_prefetch_to_host_dataset(self):
    distribution, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=0,
        num_gpus=2)
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(  # pylint: disable=assignment-from-no-return
        dataset,
        options=input_options)
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['CPU'])


class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
                                           parameterized.TestCase):
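  """Tests for ParameterServerStrategyV1 with a chief in the cluster."""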

  @classmethod
  def setUpClass(cls):
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2, has_chief=True)
    cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0]

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testSimpleBetweenGraph(self, required_gpus):
    self._run_between_graph_clients(self._test_simple_increment,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testMinimizeLossGraph(self, num_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, num_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepIsWrappedOnTwoGPUs(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    with ops.Graph().as_default(), strategy.scope():
      created_step = training_util.create_global_step()
      get_step = training_util.get_global_step()
      self.assertEqual(created_step, get_step,
                       msg=('created_step %s type %s vs. get_step %s type %s' %
                            (id(created_step), created_step.__class__.__name__,
                             id(get_step), get_step.__class__.__name__)))
      self.assertIs(ps_values.AggregatingVariable, type(created_step))
      self.assertIs(ps_values.AggregatingVariable, type(get_step))
      self.assertIs(strategy, created_step.distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepIsNotWrappedOnOneGPU(self):
    strategy, _, _ = create_test_objects(num_gpus=1)
    with ops.Graph().as_default(), strategy.scope():
      created_step = training_util.create_global_step()
      get_step = training_util.get_global_step()
      self.assertEqual(created_step, get_step,
                       msg=('created_step %s type %s vs. get_step %s type %s' %
                            (id(created_step), created_step.__class__.__name__,
                             id(get_step), get_step.__class__.__name__)))
      self.assertIs(resource_variable_ops.ResourceVariable, type(created_step))
      self.assertIs(resource_variable_ops.ResourceVariable, type(get_step))
      # All variables have a _distribute_strategy attribute. Only variable
      # subclasses in distribution strategy expose it publicly.
      self.assertFalse(hasattr(strategy, 'distribute_strategy'))
      self.assertIs(strategy, created_step._distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2))
  def testValueContainer(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    with ops.Graph().as_default(), strategy.scope():

      def f():
        with backprop.GradientTape() as tape:
          v = variable_scope.get_variable('v', initializer=10.0)
          _ = v * v
        v, = tape.watched_variables()
        w = strategy.extended.value_container(v)
        self.assertIs(ps_values.AggregatingVariable, type(w))

      strategy.extended.call_for_each_replica(f)


class CentralStorageStrategyTest(strategy_test_lib.DistributionTestBase,
                                 parameterized.TestCase):
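  """Tests for the local CentralStorageStrategy."""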

  @combinations.generate(combinations.combine(mode=['graph', 'eager'],
                                              required_gpus=2))
  def testNumpyDataset(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    self._test_numpy_dataset(strategy)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInMultiWorkerMode(self):
    strategy, _, _ = create_test_objects(num_gpus=0)
    self.assertFalse(strategy.extended._in_multi_worker_mode())


if __name__ == '__main__':
  test.main()