# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ParameterServerStrategy."""

import copy
import threading

from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_util

CHIEF = run_config.TaskType.CHIEF
WORKER = run_config.TaskType.WORKER
PS = run_config.TaskType.PS


def _get_replica_id_integer():
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    replica_id = tensor_util.constant_value(replica_id)
  return replica_id


def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        sess_config=None):
  sess_config = sess_config or config_pb2.ConfigProto()
  if num_gpus is None:
    num_gpus = context.num_gpus()
  if cluster_spec and task_type and task_id is not None:
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type=task_type,
        task_id=task_id,
        num_accelerators={'GPU': num_gpus})
    distribution = parameter_server_strategy.ParameterServerStrategyV1(
        cluster_resolver)
    target = 'grpc://' + cluster_spec[WORKER][task_id]
  else:
    distribution = (
        central_storage_strategy.CentralStorageStrategy._from_num_gpus(num_gpus)
    )
    target = ''

  sess_config = copy.deepcopy(sess_config)
  sess_config = distribution.update_config_proto(sess_config)

  return distribution, target, sess_config


class ParameterServerStrategyTestBase(
    multi_worker_test_base.MultiWorkerTestBase):

  def setUp(self):
    self._result = 0
    self._lock = threading.Lock()
    self._init_condition = threading.Condition()
    self._init_reached = 0
    self._finish_condition = threading.Condition()
    self._finish_reached = 0
    self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True)
    super(ParameterServerStrategyTestBase, self).setUp()

  def _get_test_objects(self, task_type, task_id, num_gpus):
    return create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type=task_type,
        task_id=task_id,
        num_gpus=num_gpus,
        sess_config=self._sess_config)

  def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
    worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
    d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=sess_config) as sess, \
         d.scope():

      # Define a variable outside the call_for_each_replica scope.
      n = variable_scope.get_variable('n', initializer=10.0)
      self.assertEqual(n.device, '/job:ps/task:0')

      def model_fn():
        if num_gpus == 0:
          last_part_device = 'device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          last_part_device = ('device:GPU:%d' % replica_id)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, worker_device + '/' + last_part_device)
        self.assertEqual(b.device, worker_device + '/' + last_part_device)
        self.assertEqual(c.device, worker_device + '/' + last_part_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        # The variable x is on task 1 since the device_function has been
        # called once before the model_fn.
        self.assertEqual(x.device, '/job:ps/task:1')
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device,
                         '/job:worker/replica:0/task:0/%s' % last_part_device)

        # The colocate_vars_with can override the distribution's device.
        with d.extended.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          y_add = y.assign_add(array_ops.identity(x_add))
          self.assertEqual(y.device, '/job:ps/task:1')
          self.assertEqual(y_add.device, y.device)
          self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(z.device, '/job:ps/task:0')
        self.assertNotEqual(z.device, x.device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, worker_device + '/' + last_part_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, worker_device + '/device:CPU:1')

        # This ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          v = variable_scope.get_variable('v', initializer=30.0)
          h = f + 1.0
        self.assertIn('/job:ps/', u.device)
        self.assertIn('/job:ps/', v.device)
        # u and v are on different parameter servers.
        self.assertTrue(u.device != x.device or v.device != x.device)
        self.assertTrue(u.device == x.device or v.device == x.device)
        # Here h is not on one worker. Note that h.device is canonical while
        # x.device is not.
        self.assertIn('/job:ps/', h.device)
        return y_add, z_add, f

      y, z, f = d.extended.call_for_each_replica(model_fn)
      self.assertNotEqual(y, None)
      self.assertNotEqual(z, None)
      self.assertNotEqual(f, None)

      if context.num_gpus() >= 1 and num_gpus <= 1:
        self.evaluate(variables.global_variables_initializer())
        y_val, z_val, f_val = sess.run([y, z, f])
        self.assertEqual(y_val, 33.0)
        self.assertEqual(z_val, 43.0)
        self.assertEqual(f_val, 46.0)

  def _test_device_assignment_distributed_enable_partitioner(
      self, task_type, task_id, num_gpus):
    d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
    num_shards = len(d.extended.parameter_devices)
    partitioner = partitioned_variables.fixed_size_partitioner(num_shards)
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=sess_config) as sess, \
         d.scope():

      n = variable_scope.get_variable(
          'n',
          initializer=constant_op.constant([10.0, 20.0]),
          aggregation=variable_scope.VariableAggregation.SUM,
          partitioner=partitioner)

      for part_id, var in enumerate(n):
        self.assertEqual(var.device, '/job:ps/task:%d' % part_id)

      def model_fn():
        a = constant_op.constant([3.0, 5.0])
        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x',
              initializer=constant_op.constant([10.0, 20.0]),
              aggregation=variable_scope.VariableAggregation.SUM,
              partitioner=partitioner)
          x_add = x.assign_add(a, name='x_add')
          # The variable x is on task 1 since the device_function has been
          # called once before the model_fn.
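          # Each shard of the partitioned variable should land on its own
          # parameter server task, with the shard index matching the ps task
          # index, which the loop below verifies.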
          for part_id, var in enumerate(x):
            self.assertEqual(var.device, '/job:ps/task:%d' % part_id)
            self.assertEqual(var.device, x_add[part_id].device)

          return x_add

      x = d.extended.call_for_each_replica(model_fn)

      if context.num_gpus() >= 1:
        self.evaluate(variables.global_variables_initializer())
        x_val = sess.run(x)
        if num_gpus < 1:
          self.assertEqual(x_val, [13.0, 25.0])
        else:
          x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus]
          self.assertEqual(x_val, x_expect)

  def _test_device_assignment_local(self,
                                    d,
                                    compute_device='CPU',
                                    variable_device='CPU',
                                    num_gpus=0):
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=self._sess_config) as sess, \
         d.scope():

      def model_fn():
        if 'CPU' in compute_device:
          replica_compute_device = '/device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          replica_compute_device = ('/device:GPU:%d' % replica_id)
        replica_compute_device = device_util.canonicalize(
            replica_compute_device)

        if 'CPU' in variable_device:
          replica_variable_device = '/device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          replica_variable_device = ('/device:GPU:%d' % replica_id)
        replica_variable_device = device_util.canonicalize(
            replica_variable_device)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, replica_compute_device)
        self.assertEqual(b.device, replica_compute_device)
        self.assertEqual(c.device, replica_compute_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/device:GPU:2'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        self.assertEqual(
            device_util.canonicalize(x.device), replica_variable_device)
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2'))

        # The colocate_vars_with can override the distribution's device.
        with d.extended.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          y_add = y.assign_add(array_ops.identity(x_add))
          self.assertEqual(
              device_util.canonicalize(y.device), replica_variable_device)
          self.assertEqual(y_add.device, y.device)
          self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(
            device_util.canonicalize(z.device), replica_variable_device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, replica_compute_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))

        # This ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          h = f + 1.0
        self.assertEqual(
            device_util.canonicalize(u.device), replica_variable_device)
        self.assertEqual(
            device_util.canonicalize(x.device),
            device_util.canonicalize(h.device))
        return y_add, z_add, f

      y, z, f = d.extended.call_for_each_replica(model_fn)
      self.assertNotEqual(y, None)
      self.assertNotEqual(z, None)
      self.assertNotEqual(f, None)

      if context.num_gpus() >= 1 and num_gpus <= 1:
        self.evaluate(variables.global_variables_initializer())
        y_val, z_val, f_val = sess.run([y, z, f])
        self.assertEqual(y_val, 33.0)
        self.assertEqual(z_val, 43.0)
        self.assertEqual(f_val, 46.0)

  def _test_simple_increment(self, task_type, task_id, num_gpus):
    d, master_target, sess_config = self._get_test_objects(
        task_type, task_id, num_gpus)
    if d.extended._cluster_spec:
      num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER))
      if 'chief' in d.extended._cluster_spec.as_dict():
        num_workers += 1
    else:
      num_workers = 1
    with ops.Graph().as_default(), \
         self.cached_session(target=master_target,
                             config=sess_config) as sess, \
         d.scope():

      def model_fn():
        x = variable_scope.get_variable(
            'x', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        y = variable_scope.get_variable(
            'y', initializer=20.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        z = variable_scope.get_variable(
            'z', initializer=30.0,
            aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)

        # We explicitly make a constant tensor here to avoid complaints about
        # summing non-distributed values.
        one = constant_op.constant(1.0)
        x_add = x.assign_add(one, use_locking=True)
        y_add = y.assign_add(one, use_locking=True)
        z_add = z.assign_add(one, use_locking=True)

        train_op = control_flow_ops.group(x_add, y_add, z_add)
        return x, y, z, train_op

      x, y, z, train_op = d.extended.call_for_each_replica(model_fn)
      train_op = d.group(train_op)

      if task_id == 0:
        self.evaluate(variables.global_variables_initializer())

      # Workers wait for the chief worker to finish initializing variables.
      self._init_condition.acquire()
      self._init_reached += 1
      while self._init_reached != num_workers:
        self._init_condition.wait()
      self._init_condition.notify_all()
      self._init_condition.release()

      sess.run(train_op)

      # Wait for other workers to finish training.
      self._finish_condition.acquire()
      self._finish_reached += 1
      while self._finish_reached != num_workers:
        self._finish_condition.wait()
      self._finish_condition.notify_all()
      self._finish_condition.release()

      x_val, y_val, z_val = sess.run([x, y, z])
      self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas_in_sync)
      self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync)
      self.assertEqual(z_val, 30.0 + 1.0 * num_workers)

  def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
    d, master_target, sess_config = self._get_test_objects(
        task_type, task_id, num_gpus)
    if task_type:
      # Multi-worker
      assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec
      num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER))
      if CHIEF in d.extended._cluster_spec.as_dict():
        num_workers += 1
    else:
      # local
      num_workers = 1

    with ops.Graph().as_default(), \
         self.cached_session(target=master_target,
                             config=sess_config) as sess, \
         d.scope():
      kernel = strategy_test_lib.create_variable_like_keras_layer(
          'kernel', (1, 1), dtypes.float32,)

      def loss_fn(x):
        y = array_ops.reshape(
            math_ops.matmul(x, kernel), []) - constant_op.constant(1.)
        return y * y

      # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
      # multiple graphs (b/111216820).
      def grad_fn(x):
        loss = loss_fn(x)
        var_list = (
            variables.trainable_variables() + ops.get_collection(
                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
        grads = gradients.gradients(loss, var_list)
        ret = list(zip(grads, var_list))
        return ret

      def update(v, g):
        return v.assign_sub(0.05 * g, use_locking=True)

      one = constant_op.constant([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            # TODO(yuefengz): support non-Mirrored variable as destinations.
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      before_out, after_out = step()

      if (not task_type or
          multi_worker_util.is_chief(
              d.extended._cluster_spec, task_type, task_id)):
        self.evaluate(variables.global_variables_initializer())

      # Workers wait for the chief worker to finish initializing variables.
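      # A simple barrier built on a condition variable: each client increments
      # the shared counter and blocks until all num_workers clients have
      # reached this point.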
      self._init_condition.acquire()
      self._init_reached += 1
      while self._init_reached != num_workers:
        self._init_condition.wait()
      self._init_condition.notify_all()
      self._init_condition.release()

      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
        after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_input_fn_iterator(self,
                              task_type,
                              task_id,
                              num_gpus,
                              input_fn,
                              expected_values,
                              test_reinitialize=True,
                              ignore_order=False):
    distribution, master_target, config = self._get_test_objects(
        task_type, task_id, num_gpus)
    devices = distribution.extended.worker_devices

    with ops.Graph().as_default(), \
         self.cached_session(config=config,
                             target=master_target) as sess:
      iterator = distribution.make_input_fn_iterator(input_fn)
      sess.run(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([distribute_utils.select_replica(
            r, next_element) for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
          self.assertEqual(expected_value, computed_value)

      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        sess.run([distribute_utils.select_replica(r, next_element)
                  for r in range(len(devices))])

      # After re-initializing the iterator, we should be able to iterate again.
      if test_reinitialize:
        sess.run(iterator.initializer)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = sess.run([distribute_utils.select_replica(
              r, next_element) for r in range(len(devices))])
          if ignore_order:
            self.assertCountEqual(expected_value, computed_value)
          else:
            self.assertEqual(expected_value, computed_value)


class ParameterServerStrategyTest(
    ParameterServerStrategyTestBase,
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2)
    cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0]

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_num_replicas_in_sync(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    # All the devices on a given worker are in sync, which in this case is the
    # number of GPUs on each worker.
    self.assertEqual(2, strategy.num_replicas_in_sync)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalCPU(self):
    strategy, _, _ = create_test_objects(num_gpus=0)
    self._test_device_assignment_local(
        strategy, compute_device='CPU', variable_device='CPU', num_gpus=0)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalOneGPU(self):
    strategy, _, _ = create_test_objects(num_gpus=1)
    self._test_device_assignment_local(
        strategy, compute_device='GPU', variable_device='GPU', num_gpus=1)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalTwoGPUs(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    self._test_device_assignment_local(
        strategy, compute_device='GPU', variable_device='CPU', num_gpus=2)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testDeviceAssignmentDistributed(self, num_gpus):
    self._test_device_assignment_distributed('worker', 1, num_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus):
    self._test_device_assignment_distributed_enable_partitioner(
        'worker', 1, num_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testSimpleBetweenGraph(self):
    self._run_between_graph_clients(self._test_simple_increment,
                                    self._cluster_spec, context.num_gpus())

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testLocalSimpleIncrement(self, required_gpus):
    self._test_simple_increment(None, 0, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraphDistributed(self, required_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraphLocal(self, required_gpus):
    self._test_minimize_loss_graph(None, None, required_gpus)

  # TODO(priyag): Refactor this and other multi worker tests.
  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=[1, 2], use_dataset=[True, False]))
  def testMakeInputFnIteratorDistributed(self, required_gpus, use_dataset):
    if use_dataset:
      fn = lambda: dataset_ops.Dataset.range(100)
    else:
      def fn():
        dataset = dataset_ops.Dataset.range(100)
        it = dataset_ops.make_one_shot_iterator(dataset)
        return it.get_next

    expected_values = [[i + j
                        for j in range(required_gpus)]
                       for i in range(0, 100, required_gpus)]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=3,
        expected_input_pipeline_id=1)  # because task_id = 1
    self._test_input_fn_iterator(
        'worker',
        1,
        required_gpus,
        input_fn,
        expected_values,
        test_reinitialize=use_dataset,
        ignore_order=not use_dataset)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=[1, 2], use_dataset=[True, False]))
  def testMakeInputFnIteratorLocal(self, required_gpus, use_dataset):
    if use_dataset:
      fn = lambda: dataset_ops.Dataset.range(100)
    else:

      def fn():
        dataset = dataset_ops.Dataset.range(100)
        it = dataset_ops.make_one_shot_iterator(dataset)
        return it.get_next

    expected_values = [[i + j
                        for j in range(required_gpus)]
                       for i in range(0, 100, required_gpus)]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)  # only one worker and pipeline for local.
    self._test_input_fn_iterator(
        None,
        None,
        required_gpus,
        input_fn,
        expected_values,
        test_reinitialize=use_dataset,
        ignore_order=not use_dataset)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepUpdate(self):
    strategy, _, _ = create_test_objects()
    self._test_global_step_update(strategy)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testUpdateConfigProtoMultiWorker(self):
    strategy, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=1,
        num_gpus=2)

    config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])

    new_config = strategy.update_config_proto(config_proto)

    # Verify device filters.
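    # The strategy restricts this worker's session to its own worker task and
    # the ps job, replacing the placeholder filter supplied above.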
    self.assertEqual(['/job:worker/task:1', '/job:ps'],
                     new_config.device_filters)

    # Verify isolate_session_state
    self.assertFalse(new_config.isolate_session_state)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testUpdateConfigProtoLocal(self):
    strategy, _, _ = create_test_objects(num_gpus=2)

    config_proto = config_pb2.ConfigProto()
    new_config = strategy.update_config_proto(config_proto)

    # Verify isolate_session_state
    self.assertTrue(new_config.isolate_session_state)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInMultiWorkerMode(self):
    strategy, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=1,
        num_gpus=0)
    self.assertTrue(strategy.extended._in_multi_worker_mode())

  @combinations.generate(combinations.combine(mode=['eager']))
  def testEagerCustomTrainingUnimplementedError(self):
    cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2)
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type='worker',
        task_id=1,
        num_accelerators={'GPU': 0})
    strategy = parameter_server_strategy.ParameterServerStrategyV1(
        cluster_resolver)
    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])

    def train_step(data):
      return math_ops.square(data)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.experimental_distribute_dataset,
                           dataset.batch(2))

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.distribute_datasets_from_function,
                           lambda _: dataset)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.scope)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.run, train_step)

  @combinations.generate(combinations.combine(
      mode=['graph'],
      prefetch_to_device=[None, True]))
  def test_prefetch_to_device_dataset(self, prefetch_to_device):
    distribution, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=0,
        num_gpus=2)
    if prefetch_to_device is None:
      input_options = None
    else:
      input_options = distribute_lib.InputOptions(
          experimental_fetch_to_device=prefetch_to_device)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(  # pylint: disable=assignment-from-no-return
        dataset,
        options=input_options)
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['GPU'])

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_prefetch_to_host_dataset(self):
    distribution, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=0,
        num_gpus=2)
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
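    # With experimental_fetch_to_device=False the per-replica tensors should
    # stay on the host, so only CPU device types are expected below.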
    dataset = distribution.experimental_distribute_dataset(  # pylint: disable=assignment-from-no-return
        dataset,
        options=input_options)
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['CPU'])


class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
                                           parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2, has_chief=True)
    cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0]

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testSimpleBetweenGraph(self, required_gpus):
    self._run_between_graph_clients(self._test_simple_increment,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testMinimizeLossGraph(self, num_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, num_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepIsWrappedOnTwoGPUs(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    with ops.Graph().as_default(), strategy.scope():
      created_step = training_util.create_global_step()
      get_step = training_util.get_global_step()
      self.assertEqual(created_step, get_step,
                       msg=('created_step %s type %s vs. get_step %s type %s' %
                            (id(created_step), created_step.__class__.__name__,
                             id(get_step), get_step.__class__.__name__)))
      self.assertIs(ps_values.AggregatingVariable, type(created_step))
      self.assertIs(ps_values.AggregatingVariable, type(get_step))
      self.assertIs(strategy, created_step.distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepIsNotWrappedOnOneGPU(self):
    strategy, _, _ = create_test_objects(num_gpus=1)
    with ops.Graph().as_default(), strategy.scope():
      created_step = training_util.create_global_step()
      get_step = training_util.get_global_step()
      self.assertEqual(created_step, get_step,
                       msg=('created_step %s type %s vs. get_step %s type %s' %
                            (id(created_step), created_step.__class__.__name__,
                             id(get_step), get_step.__class__.__name__)))
      self.assertIs(resource_variable_ops.ResourceVariable, type(created_step))
      self.assertIs(resource_variable_ops.ResourceVariable, type(get_step))
      # All variables have a _distribute_strategy attribute. Only variable
      # subclasses in distribution strategy expose it publicly.
      self.assertFalse(hasattr(strategy, 'distribute_strategy'))
      self.assertIs(strategy, created_step._distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2))
  def testValueContainer(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    with ops.Graph().as_default(), strategy.scope():

      def f():
        with backprop.GradientTape() as tape:
          v = variable_scope.get_variable('v', initializer=10.0)
          _ = v * v
        v, = tape.watched_variables()
        w = strategy.extended.value_container(v)
        self.assertIs(ps_values.AggregatingVariable, type(w))

      strategy.extended.call_for_each_replica(f)


class CentralStorageStrategyTest(strategy_test_lib.DistributionTestBase,
                                 parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager'],
                                              required_gpus=2))
  def testNumpyDataset(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    self._test_numpy_dataset(strategy)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInMultiWorkerMode(self):
    strategy, _, _ = create_test_objects(num_gpus=0)
    self.assertFalse(strategy.extended._in_multi_worker_mode())


if __name__ == '__main__':
  test.main()