# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CollectiveAllReduceStrategy."""

import copy
import functools

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training.server_lib import ClusterSpec


CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)
_CollectiveAllReduceStrategyExperimental = (
    collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental)


# TODO(b/231630416): Create more tests to cover the case that strategy uses
# different number of GPUs than the number of physical devices.
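
# Descriptive note (added): builds a CollectiveAllReduceStrategy and the
# session target to use with it. A multi-worker configuration is used when
# cluster_spec, task_type, and task_id are all provided; otherwise the
# strategy is created for a single local worker with an empty target.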
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        num_tpus=None):
  if num_gpus is None:
    num_gpus = context.num_gpus()
  if num_tpus is None:
    num_tpus = context.context().list_physical_devices('TPU')
  if num_tpus:
    tpu_strategy_util.initialize_tpu_system()

  if cluster_spec and task_type and task_id is not None:
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type=task_type,
        task_id=task_id,
        num_accelerators={'GPU': num_gpus, 'TPU': num_tpus})
    target = 'grpc://' + cluster_spec[task_type][task_id]
  else:
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec({}), num_accelerators={'GPU': num_gpus, 'TPU': num_tpus})
    target = ''

  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
      cluster_resolver=cluster_resolver)

  return strategy, target


class CollectiveAllReduceStrategyTestBase(
    multi_worker_test_base.MultiWorkerTestBase):

  def setUp(self):
    # We use a different key_base for each test so that collective keys won't
    # be reused.
    CollectiveAllReduceStrategy._collective_key_base += 100000
    super(CollectiveAllReduceStrategyTestBase, self).setUp()

  def _get_test_object(self,
                       task_type,
                       task_id,
                       num_gpus=0,
                       num_tpus=0,
                       use_devices_arg=False):
    strategy, target = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type=task_type,
        task_id=task_id,
        num_gpus=num_gpus,
        num_tpus=num_tpus)

    if use_devices_arg and num_gpus > 0:
      devices = ['GPU:%d' % i for i in range(num_gpus)]
      # Temporary workaround to manually set the `_extended` field before
      # device initialization is exposed as a public interface.
      strategy._extended = CollectiveAllReduceExtended(
          container_strategy=strategy,
          cluster_resolver=None,
          communication_options=collective_util.Options(),
          devices=devices)
      # Manually set the field since the workaround bypasses the base
      # constructor, resulting in the absence of this field.
      strategy._extended._retrace_functions_for_each_device = (num_gpus > 1)

    return strategy, target

  def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
    distribution, master_target = self._get_test_object(task_type, task_id,
                                                        num_gpus)
    with ops.Graph().as_default(), \
         self.cached_session(target=master_target) as sess, \
         distribution.scope():
      initializer = functools.partial(
          init_ops_v2.GlorotUniform(), (1, 1), dtype=dtypes.float32)
      kernel = variables.Variable(
          initial_value=initializer,
          name='gpu_%d/kernel' % distribution.extended._num_devices_per_worker,
          trainable=True)

      def loss_fn(x):
        y = array_ops.reshape(
            gen_math_ops.mat_mul(x, kernel), []) - constant_op.constant(1.)
        return y * y

      # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
      # multiple graphs (b/111216820).
      def grad_fn(x):
        loss = loss_fn(x)
        var_list = (
            variables.trainable_variables() + ops.get_collection(
                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
        grads = gradients.gradients(loss, var_list)
        ret = list(zip(grads, var_list))
        return ret

      def update(v, g):
        return v.assign_sub(0.05 * g, use_locking=True)

      one = constant_op.constant([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = distribution.extended.call_for_each_replica(grad_fn, args=[one])
        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = distribution.extended.read_var(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            # TODO(yuefengz): support non-Mirrored variable as destinations.
            g = distribution.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                distribution.extended.update(v, update, args=(g,),
                                             group=False)):
              after_list.append(distribution.extended.read_var(v))
        return before_list, after_list

      before_out, after_out = step()

      if (distribution.extended._local_device_type == 'GPU' and
          context.num_gpus() < distribution.extended._num_devices_per_worker):
        return True

      sess.run(variables.global_variables_initializer())

      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
        after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_variable_initialization(self, task_type, task_id, num_gpus):
    distribution, master_target = self._get_test_object(task_type, task_id,
                                                        num_gpus)
    with ops.Graph().as_default(), \
         self.cached_session(target=master_target) as sess, \
         distribution.scope():

      def model_fn():
        x = variable_scope.get_variable(
            'x',
            shape=(2, 3),
            initializer=init_ops.random_uniform_initializer(
                1.0, 10.0, dtype=dtypes.float32))
        return array_ops.identity(x)

      x = distribution.extended.call_for_each_replica(model_fn)
      reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
      x = distribution.experimental_local_results(x)[0]

      sess.run(variables.global_variables_initializer())

      x_value, reduced_x_value = sess.run([x, reduced_x])
      self.assertTrue(
          np.allclose(x_value, reduced_x_value, atol=1e-5),
          msg=('x_value = %r, reduced_x_value = %r' % (x_value,
                                                       reduced_x_value)))

  def _test_input_fn_iterator(self,
                              task_type,
                              task_id,
                              num_gpus,
                              input_fn,
                              expected_values,
                              test_reinitialize=True,
                              ignore_order=False,
                              use_devices_arg=False):
    distribution, master_target = self._get_test_object(
        task_type, task_id, num_gpus, use_devices_arg=use_devices_arg)
    devices = distribution.extended.worker_devices

    with ops.Graph().as_default(), \
         self.cached_session(target=master_target) as sess:
      iterator = distribution.make_input_fn_iterator(input_fn)
      sess.run(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([distribute_utils.select_replica(
            r, next_element) for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(list(expected_value), list(computed_value))
        else:
          self.assertEqual(list(expected_value), list(computed_value))

      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        sess.run([distribute_utils.select_replica(r, next_element)
                  for r in range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
      if test_reinitialize:
        sess.run(iterator.initializer)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = sess.run([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))])
          if ignore_order:
            self.assertCountEqual(list(expected_value), list(computed_value))
          else:
            self.assertEqual(list(expected_value), list(computed_value))


class DistributedCollectiveAllReduceStrategyTest(
    CollectiveAllReduceStrategyTestBase,
    strategy_test_lib.DistributionTestBase,
    parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 3 workers."""
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=0)

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_num_replicas_in_sync(self):
    distribution, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=0,
        num_gpus=2)
    num_workers = len(self._cluster_spec.get('chief', []) +
                      self._cluster_spec.get('worker', []))
    self.assertEqual(2 * num_workers,
                     distribution.num_replicas_in_sync)

  @combinations.generate(combinations.combine(
      mode=['graph'],
      prefetch_to_device=[None, True]))
  def test_prefetch_to_device_dataset(self, prefetch_to_device):
    distribution, _ = self._get_test_object(
        task_type='worker', task_id=0, num_gpus=2)
    if prefetch_to_device is None:
      input_options = None
    else:
      input_options = distribute_lib.InputOptions(
          experimental_fetch_to_device=prefetch_to_device)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['GPU'])

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_prefetch_to_host_dataset(self):
    distribution, _ = self._get_test_object(
        task_type='worker', task_id=0, num_gpus=2)
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['CPU'])

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraph(self, required_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testVariableInitialization(self, required_gpus):
    self._run_between_graph_clients(
        self._test_variable_initialization,
        self._cluster_spec,
        num_gpus=required_gpus)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=[0, 1, 2], use_dataset=[True, False]))
  def testMakeInputFnIterator(self, required_gpus, use_dataset):
    def _worker_fn(task_type, task_id, required_gpus):
      if use_dataset:
        fn = lambda: dataset_ops.Dataset.range(20)
      else:
        def fn():
          dataset = dataset_ops.Dataset.range(20)
          it = dataset_ops.make_one_shot_iterator(dataset)
          return it.get_next
      # We use CPU as the device when required_gpus = 0
      devices_per_worker = max(1, required_gpus)
      expected_values = [[i+j for j in range(devices_per_worker)]
                         for i in range(0, 20, devices_per_worker)]

      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=3*devices_per_worker,
          expected_num_input_pipelines=3,
          expected_input_pipeline_id=task_id)
      self._test_input_fn_iterator(
          task_type,
          task_id,
          required_gpus,
          input_fn,
          expected_values,
          test_reinitialize=use_dataset,
          ignore_order=not use_dataset)

    self._run_between_graph_clients(_worker_fn, self._cluster_spec,
                                    required_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testUpdateConfigProto(self):
    strategy, _ = self._get_test_object(
        task_type='worker', task_id=1, num_gpus=2)

    config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])
    rewrite_options = config_proto.graph_options.rewrite_options
    rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed')

    new_config = strategy.update_config_proto(config_proto)

    # Verify group leader
    self.assertEqual('/job:worker/replica:0/task:0',
                     new_config.experimental.collective_group_leader)

    # Verify device filters.
    self.assertEqual(['/job:worker/task:1'], new_config.device_filters)

    # Verify rewrite options.
    new_rewrite_options = new_config.graph_options.rewrite_options
    self.assertEqual(rewriter_config_pb2.RewriterConfig.ON,
                     new_rewrite_options.scoped_allocator_optimization)
    self.assertEqual(['CollectiveReduce'],
                     new_rewrite_options.scoped_allocator_opts.enable_op)


class DistributedCollectiveAllReduceStrategyTestWithChief(
    CollectiveAllReduceStrategyTestBase, parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 3 workers and 1 chief."""
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=0, has_chief=True)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraph(self, required_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testVariableInitialization(self, required_gpus):
    self._run_between_graph_clients(
        self._test_variable_initialization,
        self._cluster_spec,
        num_gpus=required_gpus)


class SingleWorkerCollectiveAllReduceStrategy(
    CollectiveAllReduceStrategyTestBase, strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['eager']))
  def testStrategyInitializationError(self):
    with self.assertRaisesRegex(
        ValueError,
        'cluster_resolver and devices cannot be set at the same time'):
      _ = collective_all_reduce_strategy.CollectiveAllReduceExtended(
          container_strategy=None,
          cluster_resolver=multi_worker_test_base.create_in_process_cluster(
              num_workers=3, num_ps=0),
          communication_options=collective_util.Options(),
          devices=['GPU:0', 'GPU:1'])

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testMinimizeLoss(self, required_gpus, use_devices_arg):
    # Collective ops don't support a strategy with only one device.
    if context.executing_eagerly():
      strategy, _ = self._get_test_object(
          None, None, required_gpus, use_devices_arg=use_devices_arg)
      self._test_minimize_loss_eager(strategy)
    else:
      self._test_minimize_loss_graph(None, None, required_gpus)

  @combinations.generate(
      combinations.combine(
          mode=['eager'], required_gpus=[1, 2], use_devices_arg=[True, False]))
  def testNumReplicasInSync(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, required_gpus, use_devices_arg=use_devices_arg)
    self.assertEqual(required_gpus, strategy.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=['eager'],
          required_tpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testMinimizeLossTPU(self, required_tpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, num_tpus=required_tpus, use_devices_arg=use_devices_arg)
    self._test_minimize_loss_eager(strategy)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testCallAndMergeExceptions(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    self._test_call_and_merge_exceptions(strategy)

  @combinations.generate(
      combinations.combine(
          mode=['graph'],
          required_gpus=2,
          use_dataset=[True, False],
          use_devices_arg=[True, False]))
  def testMakeInputFnIterator(self, required_gpus, use_dataset,
                              use_devices_arg):
    if use_dataset:
      fn = lambda: dataset_ops.Dataset.range(5 * required_gpus)
    else:
      def fn():
        dataset = dataset_ops.Dataset.range(5 * required_gpus)
        it = dataset_ops.make_one_shot_iterator(dataset)
        return it.get_next

    expected_values = [
        range(i, i + required_gpus) for i in range(0, 10, required_gpus)
    ]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterator(
        None,
        None,
        required_gpus,
        input_fn,
        expected_values,
        test_reinitialize=use_dataset,
        ignore_order=not use_dataset)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testReduceToCpu(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, required_gpus, use_devices_arg=use_devices_arg)
    with strategy.scope():
      result = strategy.extended.call_for_each_replica(_replica_id_f32)
      reduced = strategy.reduce(reduce_util.ReduceOp.SUM, result, axis=None)
      expected = sum(range(strategy.num_replicas_in_sync))
      self.assertEqual(expected, self.evaluate(reduced))

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceSum(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_sum(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceSumGradients(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_sum_gradients(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceSumGradientTape(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_sum_gradient_tape(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceMean(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_mean(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceMeanGradients(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_mean_gradients(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testAllReduceMeanGradientTape(self, required_gpus, use_devices_arg):
    distribution, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_all_reduce_mean_gradient_tape(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=2, use_devices_arg=[True, False]))
  def testNumpyDataset(self, required_gpus, use_devices_arg):
    strategy, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    self._test_numpy_dataset(
        strategy, session=self.cached_session(target=target))

  @combinations.generate(
      combinations.combine(
          mode=['eager'], required_gpus=2, use_devices_arg=[True, False]))
  def testReplicateDataset(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i + 1] for i in range(0, 10, 2)]
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterable(strategy, input_fn, expected_values)

  @combinations.generate(
      combinations.combine(mode=['graph'], use_devices_arg=[True, False]))
  def testDeepCopy(self, use_devices_arg):
    distribution, _ = self._get_test_object(
        None, None, use_devices_arg=use_devices_arg)
    copy.deepcopy(distribution)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testSummaryForReplicaZeroOnly(self, required_gpus, use_devices_arg):
    strategy, target = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    with self.cached_session(target=target):
      self._test_summary_for_replica_zero_only(strategy)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          required_gpus=[0, 1, 2],
          use_devices_arg=[True, False]))
  def testTrainableVariables(self, required_gpus, use_devices_arg):
    strategy, _ = self._get_test_object(
        None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg)
    self._test_trainable_variable(strategy)


class LogicalDeviceTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['eager'], required_gpus=1))
  def testKeepLogicalDevice(self):
    gpus = tf_config.list_physical_devices('GPU')
    if len(gpus) > 1:
      self.skipTest('Skip logical device test on multi GPUs, since partial GPU '
                    'virtualization is not permitted.')
    # Cannot change logical device after the context initialization.
    context._reset_context()  # pylint: disable=protected-access
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=False, num_workers=1)
    resolver = cluster_resolver_lib.SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type='worker',
        task_id=0)

    logical_gpus = len(gpus) * 2
    for i, device in enumerate(gpus):
      n = (i + 1) * logical_gpus // len(gpus) - i * logical_gpus // len(gpus)
      assert n > 0  # guaranteed if count >= len(devices)
      configs = []
      for ordinal in range(n):
        config = context.LogicalDeviceConfiguration(
            memory_limit=64,
            experimental_device_ordinal=ordinal)
        configs.append(config)

      tf_config.set_logical_device_configuration(device, configs)

    collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        cluster_resolver=resolver)
    # Since we create two logical GPUs out of the last GPU, there should be one
    # more logical GPU than physical GPUs.
    self.assertLen(tf_config.list_logical_devices('GPU'), logical_gpus)
    context._reset_context()  # pylint: disable=protected-access


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        mode=['eager']))
class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    if context.context().list_physical_devices('TPU'):
      self.skipTest('Test not supported on TPUs')

  def test_replica_id_in_sync_group(self, strategy):

    def replica_fn():
      replica_ctx = distribution_strategy_context.get_replica_context()
      return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id

    results = test_util.gather(strategy, strategy.run(replica_fn))
    self.assertAllEqual(list(range(strategy.extended._num_replicas_in_sync)),
                        results[0].numpy())
    self.assertAllEqual(
        list(range(len(strategy.extended.worker_devices))) *
        strategy.extended._num_workers, results[1].numpy())

  def test_deep_copy_not_allowed(self, strategy):
    # Check health is disabled in tests by default. We need to enable it for
    # this test to simulate the real world.
    strategy.extended._start_check_health_thread()
    try:
      with self.assertRaisesRegex(ValueError, 'cannot be deep copied'):
        copy.deepcopy(strategy)
      with self.assertRaisesRegex(ValueError, 'cannot be deep copied'):
        with ops.Graph().as_default():
          copy.deepcopy(strategy)
    finally:
      strategy.extended._stop_check_health_thread()


class ExperimentalCompatibilityTest(test.TestCase):

  def testIsInstance(self):
    # It's not uncommon for people to special case MultiWorkerMirroredStrategy,
    # so we need to make sure isinstance check works for combinations between
    # the experimental and non-experimental endpoints.
    strategy = CollectiveAllReduceStrategy()
    experimental_strategy = _CollectiveAllReduceStrategyExperimental()
    self.assertIsInstance(strategy, CollectiveAllReduceStrategy)
    self.assertIsInstance(strategy, _CollectiveAllReduceStrategyExperimental)
    self.assertIsInstance(experimental_strategy, CollectiveAllReduceStrategy)
    self.assertIsInstance(experimental_strategy,
                          _CollectiveAllReduceStrategyExperimental)

  def testName(self):
    # Estimator checks the __name__ to special case MultiWorkerMirroredStrategy.
    self.assertEqual(CollectiveAllReduceStrategy.__name__,
                     'CollectiveAllReduceStrategy')
    self.assertEqual(_CollectiveAllReduceStrategyExperimental.__name__,
                     'CollectiveAllReduceStrategy')


def _replica_id_f32():
  return math_ops.cast(
      distribution_strategy_context.get_replica_context()
      .replica_id_in_sync_group, dtypes.float32)


if __name__ == '__main__':
  # TODO(b/172304955): enable logical devices.
  test_util.main(config_logical_devices=False)