# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the input_lib library."""

import collections

from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util as framework_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
from tensorflow.python.util import nest


class DistributedIteratorTestBase(test.TestCase):

  # The passed input_context is to create a sharded dataset in between-graph
  # case.
  # TODO(yuefengz): rewrite the following method to make it less DRY.
  def _wrap_iterator(self,
                     input_type,
                     dataset_or_input_fn,
                     input_workers,
                     devices,
                     num_replicas_in_sync,
                     strategy,
                     input_context=None):
    # The `input_context` passed in is to shard dataset for
    # MultiWorkerMirroredStrategy. It doesn't apply to in-graph case where
    # multiple InputContexts are needed.
    if input_type == "input_fn":
      self.assertIsNone(
          input_context,
          msg=("The `input_context` arg is only used to shard dataset in "
               "`MultiWorkerMirroredStrategy` when the input type is dataset."))

      input_contexts = []
      for i in range(input_workers.num_workers):
        input_contexts.append(
            distribute_lib.InputContext(
                # Note: `input_workers.num_workers` is always 1 in between-graph
                # case.
                num_input_pipelines=input_workers.num_workers,
                input_pipeline_id=i,
                num_replicas_in_sync=len(devices)))

      iterator = input_lib_v1.InputFunctionIterator(dataset_or_input_fn,
                                                    input_workers,
                                                    input_contexts, strategy)
    else:
      iterator = input_lib_v1.DatasetIterator(
          dataset_or_input_fn,
          input_workers,
          strategy,
          num_replicas_in_sync=num_replicas_in_sync,
          input_context=input_context)
    return iterator

  def _wrap_dataset(self,
                    input_type,
                    dataset,
                    input_workers,
                    num_replicas_in_sync,
                    strategy,
                    input_context=None):
    if input_type == "dataset":
      if tf2.enabled():
        return input_lib.DistributedDataset(
            input_workers,
            strategy,
            dataset,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
      else:
        return input_lib_v1.DistributedDatasetV1(
            dataset,
            input_workers,
            strategy,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
    else:
      return strategy.distribute_datasets_from_function(dataset)

  def _assert_iterator_values(self,
                              iterator,
                              expected_values,
                              evaluate_fn,
                              devices,
                              enable_get_next_as_optional=False):
    actual_values = []
    for _ in range(len(expected_values)):
      if enable_get_next_as_optional:
        next_element = iterator.get_next_as_optional().get_value()
      else:
        next_element = iterator.get_next()
      computed_value = evaluate_fn([
          distribute_utils.select_replica(r, next_element)
          for r in range(len(devices))
      ])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _assert_dataset_values_for_loop(self, dataset, expected_values,
                                      evaluate_fn, devices):
    actual_values = []
    for x in dataset:
      computed_value = self.evaluate(
          [distribute_utils.select_replica(r, x) for r in range(len(devices))])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_or_input_fn,
                            worker_device_pairs,
                            expected_values,
                            strategy,
                            sess=None,
                            num_replicas_in_sync=None,
                            input_context=None):
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type,
          dataset_or_input_fn,
          input_workers,
          devices,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)
    else:
      # wrapping into a dataset:
      dataset = self._wrap_dataset(
          input_type,
          dataset_or_input_fn,
          input_workers,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)

      if ops.executing_eagerly_outside_functions():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
          iterator = dataset.make_initializable_iterator()
        else:
          self.skipTest("unsupported test combination")

    if isinstance(iterator, composite_tensor.CompositeTensor):
      nest.assert_same_structure(
          iterator, iterator._type_spec, expand_composites=True)

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      def test_get_next(iterator):
        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

        with self.assertRaises(errors.OutOfRangeError):
          self._assert_iterator_values(iterator, expected_values, evaluate,
                                       devices)

        # After re-initializing the iterator, should be able to iterate again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

      def test_get_next_as_optional(iterator):
        self._assert_iterator_values(
            iterator,
            expected_values,
            evaluate,
            devices,
            enable_get_next_as_optional=True)

        next_element = iterator.get_next_as_optional()
        self.assertFalse(self.evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          self._assert_iterator_values(
              iterator, [0],
              evaluate,
              devices,
              enable_get_next_as_optional=True)

      test_get_next(iterator)

      # re-initializing the iterator
      if not tf2.enabled():
        # TODO(yuefengz): we should split this function.
        return
      else:
        if api_type == "wrap_into_iterator":
          return
        else:
          iterator = iter(dataset)

      test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
      self._assert_dataset_values_for_loop(dataset, expected_values,
                                           self.evaluate, devices)

  def _create_dataset_or_input_fn(self, input_type, input_fn):
    if input_type == "input_fn":
      return input_fn
    else:
      return input_fn(distribute_lib.InputContext())


class DistributedIteratorTest(DistributedIteratorTestBase,
                              parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ]))
  def testMultiDeviceIterInitialize(self, distribution):
    if tf2.enabled():
      self.skipTest("Only V1 is supported.")
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)

    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dist_dataset = input_util.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)

    @def_function.function
    def init_func_for_iter():
      self.evaluate(iterator.initializer)

    init_func_for_iter()

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPUMultiWorker(self, input_type, api_type, iteration_type,
                                  distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
                                 distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i + 1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.tpu_strategy],
          enable_get_next_as_optional=[True, False]))
  def testTPU(self, input_type, api_type, iteration_type, distribution,
              enable_get_next_as_optional):
    worker_device_pairs = collections.OrderedDict()
    for tpu_device in distribution.extended.worker_devices:
      host_device = device_util.get_host_for_device(tpu_device)
      worker_device_pairs.setdefault(host_device, [])
      worker_device_pairs[host_device].append(tpu_device)
    worker_device_pairs = worker_device_pairs.items()
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i + 1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [
        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values,
                               distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ],
          enable_get_next_as_optional=[True, False]))
  def testTupleDatasetMultiworker(self, input_type, api_type, iteration_type,
                                  distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [
        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    # Input_context is not passed in and thus no sharding.
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testIterableIterator(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.Dataset.range(10)
    dist_dataset = input_util.get_distributed_dataset(dataset, input_workers,
                                                      distribution)

    iterator = iter(dist_dataset)
    for i, element in enumerate(iterator):
      self.assertAllEqual(distribution.experimental_local_results(element), [i])

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_one_cpu,
          ]))
  def testIterableIteratorError(self, distribution):
    dataset = dataset_ops.Dataset.range(10).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    iterator = iter(dist_dataset)
    # Raises error when next(iterator) is called without strategy scope
    with self.assertRaises(ValueError):

      def replica_fn1(iterator):
        return next(iterator)

      distribution.run(replica_fn1, args=(iterator,))

    if distribution.num_replicas_in_sync == 1:
      expected_result = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8, 9]]]
    elif distribution.num_replicas_in_sync == 2:
      expected_result = [[[0], [1]], [[2], [3]], [[4], [5]], [[6], [7]],
                         [[8], [9]]]

    with distribution.scope():

      def replica_fn2(iterator):
        return iterator

      result = distribution.run(replica_fn2, args=(next(iterator),))
      self.assertAllEqual(
          distribution.experimental_local_results(result), expected_result[0])

    # Confirm default ReplicaContext also works
    iterator = iter(dist_dataset)
    for i, element in enumerate(iterator):
      self.assertAllEqual(
          distribution.experimental_local_results(element), expected_result[i])

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ]))
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
                               drop_remainder, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(  # pylint: disable=g-long-lambda
        2, drop_remainder=drop_remainder)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
    if drop_remainder:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    else:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testUnevenDatasetBatchesMultiWorker(self, input_type, api_type,
                                          iteration_type, drop_remainder,
                                          distribution):
    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                  cr.task_type)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(9)

      if input_type == "input_fn":
        # When input_fn is used, there is no automatic rebatching and sharding,
        # so we add them here.
        return dataset.shard(worker_count, id_in_cluster).batch(1)
      else:
        return dataset.batch(2, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    if drop_remainder and input_type == "dataset":
      if id_in_cluster == 0:
        expected_values = [[[0]], [[2]], [[4]], [[6]]]
      else:
        expected_values = [[[1]], [[3]], [[5]], [[7]]]
    else:
      # The last global batch only contains data for one replica.
      if id_in_cluster == 0:
        expected_values = [[[0]], [[2]], [[4]], [[6]], [[8]]]
      else:
        expected_values = [[[1]], [[3]], [[5]], [[7]], [[]]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ]))
  def testUnevenDatasetBatchesMultiWorkerFourReplicas(self, input_type,
                                                      api_type, iteration_type,
                                                      drop_remainder,
                                                      distribution):
    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                  cr.task_type)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(15)

      if input_type == "input_fn":
        # When input_fn is used, there is no automatic rebatching and sharding,
        # so we add them here.
        return dataset.shard(worker_count, id_in_cluster).batch(1)
      else:
        return dataset.batch(4, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
    if drop_remainder and input_type == "dataset":
      if id_in_cluster == 0:
        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]]]
      else:
        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]]]
    else:
      if id_in_cluster == 0:
        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]], [[12], [14]]]
      else:
        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]], [[13], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          num_replicas_in_sync=[None, 2],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplitting(self, input_type, api_type, iteration_type,
                         num_replicas_in_sync, distribution,
                         enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    batch_size = 10
    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size //
        num_replicas_in_sync if num_replicas_in_sync else batch_size)
    expected_values = [[
        range(i, i + updated_batch_size),
        range(i + updated_batch_size, i + 2 * updated_batch_size)
    ] for i in range(0, 100, updated_batch_size * 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        num_replicas_in_sync=num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          num_replicas_in_sync=[None, 2],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
                                    num_replicas_in_sync, distribution,
                                    enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
    batch_size = 10
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(100).batch(batch_size)
      return dataset

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size //
        num_replicas_in_sync if num_replicas_in_sync else batch_size)
    expected_values = [
        [  # pylint: disable=g-complex-comprehension
            range(i, i + updated_batch_size),
            range(i + updated_batch_size, i + 2 * updated_batch_size)
        ] for i in range(0, 100, updated_batch_size * 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        num_replicas_in_sync=num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
      ))
  def testCacheAcrossIteration(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    dataset = dataset_ops.Dataset.range(16).shuffle(16).cache().batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          reshuffle=[True, False]))
  def testShuffleAcrossIterations(self, distribution, reshuffle):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    dataset = dataset_ops.Dataset.range(12).shuffle(
        12, reshuffle_each_iteration=reshuffle).batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    if reshuffle:
      self.assertNotAllEqual(first_epoch, second_epoch)
    else:
      self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testGetNextOptionalShapeFinite(self, distribution):
    batch_size = 8
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "feature": array_ops.ones([batch_size, 10]),
        "label": array_ops.ones([batch_size]),
    })
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    @def_function.function
    def train_fn():
      for data in dist_dataset:
        data = nest.map_structure(distribution.experimental_local_results, data)
        feature = data["feature"]
        label = data["label"]

        # Assert the shapes are still static from all replicas.
        for replica_id in range(len(distribution.extended.worker_devices)):
          self.assertEqual([None, 10],
                           feature[replica_id].shape.as_list())
          self.assertEqual([None], label[replica_id].shape.as_list())

    train_fn()

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testGetNextOptionalShapeInfinite(self, distribution):
    batch_size = 8
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "feature": array_ops.ones([batch_size, 10]),
        "label": array_ops.ones([batch_size]),
    })
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat()
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync

    @def_function.function
    def train_fn():
      data = iter(dist_dataset).get_next_as_optional().get_value()
      data = nest.map_structure(distribution.experimental_local_results, data)
      feature = data["feature"]
      label = data["label"]

      # Assert the shapes are still static from all replicas.
      for replica_id in range(len(distribution.extended.worker_devices)):
        self.assertEqual([per_replica_batch_size, 10],
                         feature[replica_id].shape.as_list())
        self.assertEqual([per_replica_batch_size],
                         label[replica_id].shape.as_list())

    train_fn()

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testGetNextOptionalShapeEmpty(self, distribution):
    batch_size = 8
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "feature": array_ops.ones([batch_size, 10]),
        "label": array_ops.ones([batch_size]),
    })
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat()
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync

    @def_function.function
    def train_fn():
      data = iter(dist_dataset).get_next_as_optional()
      feature_specs = data.element_spec["feature"]._component_specs
      value_specs = data.element_spec["label"]._component_specs
      if not isinstance(feature_specs, tuple):
        feature_specs = (feature_specs,)
        value_specs = (value_specs,)
      # Assert the shapes are still static from all replicas.
      for replica_id in range(len(distribution.extended.worker_devices)):
        self.assertEqual([per_replica_batch_size, 10],
                         feature_specs[replica_id].shape.as_list())
        self.assertEqual([per_replica_batch_size],
                         value_specs[replica_id].shape.as_list())

    train_fn()

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
  def testAutoshardingOption(self, distribution, input_type, api_type,
                             iteration_type, auto_shard_policy):
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)
    ds_option = options_lib.Options()
    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
    dataset_fn = (
        lambda _: dataset_ops.Dataset.range(4).with_options(ds_option))
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    if auto_shard_policy == AutoShardPolicy.AUTO:
      if id_in_cluster == 0:
        expected_values = [[0], [2]]
      else:
        expected_values = [[1], [3]]
    else:
      expected_values = [[0], [1], [2], [3]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          input_type=["input_fn"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"]))
  def testDifferentDatasetsMultiWorker(self, distribution, input_type, api_type,
                                       iteration_type):
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(ctx):
      if ctx.input_pipeline_id == 0:
        return dataset_ops.Dataset.range(8).batch(2)
      else:
        return dataset_ops.Dataset.range(9).batch(2)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    if id_in_cluster == 0:
      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[]]]
    else:
      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8]]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"]))
  def testLoopOverDatasetInTFFunction(self, strategy):
    dataset = dataset_ops.Dataset.range(10).map(lambda x: {  # pylint: disable=g-long-lambda
        "y": math_ops.cast(x, dtypes.float32) ** 2,
    }).batch(4)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    with strategy.scope():
      v = variables.Variable(0.0, aggregation=variables.VariableAggregation.SUM)

    @def_function.function
    def iterator_fn(dist_dataset):

      def assign_add_fn(data):
        v.assign_add(math_ops.reduce_sum(data["y"]))

      for data in dist_dataset:
        strategy.run(assign_add_fn, args=(data,))

    iterator_fn(dist_dataset)
    self.assertEqual(v.numpy(), 285.0)


class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                        parameterized.TestCase):
  """Tests for DistributedDataset with non-dense tensors."""

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          defun_type=["lambda", "tf_function"],
      ))
  def testRaggedSparse(self, distribution, input_type, drop_remainder,
                       defun_type):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    self.skipTest("b/213596871, b/214574707")

    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    defun = {
        "lambda": lambda f: f,
        "tf_function": def_function.function
    }[defun_type]
    distribution.extended.experimental_enable_get_next_as_optional = True
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
      return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)
    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
                                 distribution.extended._input_workers,
                                 distribution.num_replicas_in_sync,
                                 distribution)
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["ragged"]).values,
          6)
      self.assertAllEqual(
          distribute_utils.select_replica(
              i, per_replica_batch["ragged"]).to_tensor(),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["sparse"]).indices,
          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))

    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_for_loop(dataset):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      for batch in dataset:
        sums = _reduce(sums, batch)
      return sums

    def sum_while_loop(iterator, reduce_fn):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      while True:
        try:
          sums = reduce_fn(sums, iterator)
        except (StopIteration, errors.OutOfRangeError):
          return sums

    while_sums = sum_while_loop(
        iter(dataset),
        defun(lambda state, iterator: _reduce(state, next(iterator))))
    self.assertAllEqual(
        nest.flatten(while_sums),
        # When there's no partial batch, the sum is smaller.
        [200. if drop_remainder else 310.] * 3)
    for_sums = defun(sum_for_loop)(dataset)
    # For loops always call get next as optional inside tf functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round robined over 2 replicas).
    expected_for_sum = 200.
    if (not drop_remainder or
        (defun_type == "tf_function" and input_type == "input_fn")):
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          tensor_type=["sparse", "ragged"],
          enable_get_next_as_optional=[True, False]))
  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
                                        drop_remainder, tensor_type,
                                        enable_get_next_as_optional):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          tensor_type: (ragged_tensor if tensor_type == "ragged" else
                        ragged_tensor.to_sparse()),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)
    iterator = iter(ds)

    self.assertEqual(iterator._enable_get_next_as_optional,
                     (not drop_remainder) and enable_get_next_as_optional)

  @combinations.generate(
      combinations.combine(
          tf_api_version=2,
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
      ))
  def testRaggedSparseGetNextAsOptionalInLoop(self, distribution, input_type,
                                              drop_remainder):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
      return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)

    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):

        def _sum(value):
          if sparse_tensor.is_sparse(value):
            return math_ops.reduce_sum(value.values)
          else:
            return math_ops.reduce_sum(value)

        per_replica_sums = distribution.run(_sum, args=(per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_while_loop(ds):
      iterator = iter(ds)
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      try_next = constant_op.constant(True)

      while try_next:
        opt_iterate = iterator.get_next_as_optional()
        if opt_iterate.has_value():
          sums = _reduce(sums, opt_iterate.get_value())
        else:
          try_next = False
      return sums

    sums = def_function.function(sum_while_loop)(ds)
    # For loops always call get next as optional inside tf functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round robined over 2 replicas).
    expected_for_sum = 200.
    if not drop_remainder or input_type == "input_fn":
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
                           distribution):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps.
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(12).batch(8)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
                                            iteration_type, distribution):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps. However, when we create a
    # DistributedDataset and cannot statically infer the intended global batch
    # size (e.g. if the user does not use a batching dataset), each worker will
    # rebatch based on the dynamic batch size of the data encountered, even when
    # it encounters partial batches. The last per-worker partial batch (size 4)
    # ends up being split into two replicas, resulting in 4 steps in total, of
    # (global) batch sizes 8, 8, 4, 4.
    def dataset_fn(ctx):
      del ctx
      # The following dataset is equivalent to
      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
      # This causes DistributedDataset to use LegacyRebatch instead.
      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))

      def map_fn(offset, batch_size):
        return math_ops.range(offset, offset + batch_size)

      dataset = dataset.map(map_fn)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is equivalent to
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
                               distribution, auto_shard_policy):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when the dataset is sharded by
    # data and the batch size is indivisible by the number of replicas. This
    # checks that the elements are as expected and the batch size across all
    # workers adds up to 3. This test will only pass if the autoshard rewrite
    # rewrites RebatchDatasetV2 to legacy RebatchDataset when sharding by data.
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(8).batch(3)

      # Apply the auto-shard policy under test; `dataset` defines the
      # per-worker dataset, which each worker will then rebatch and shard by
      # data.
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = auto_shard_policy
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. We expect each worker to see different shards of
    # data.
    cr = distribution.cluster_resolver
    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
                                                cr.task_id)

    if worker_id == 0:
      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
    elif worker_id == 1:
      expected_values = [[[2]], [[5]], [[7]]]

    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())


@framework_test_util.with_eager_op_as_function
class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
                                       parameterized.TestCase):
  """Tests for PER_WORKER and PER_REPLICA's InputOptions variants."""

  def setUp(self):
    context._reset_context()
    strategy_combinations.set_virtual_cpus_to_at_least(3)
    super(DistributedIteratorPerDeviceTest, self).setUp()

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
                                                        input_options):

    def dataset_fn(input_context):  # pylint: disable=[unused-argument]
      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      assert x.values[0].device == distribution.extended.worker_devices[0]
      assert x.values[0].backing_device == distribution.extended.worker_devices[
          0]
      assert x.values[1].device == distribution.extended.worker_devices[1]
      assert x.values[1].backing_device == distribution.extended.worker_devices[
          1]
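
  # A loose sketch (illustrative only; `strategy` and `dataset_fn` are assumed
  # to exist, and `tf` is `import tensorflow as tf`) of the knob the placement
  # tests here exercise: with experimental_fetch_to_device=True the per-replica
  # values are prefetched onto the replica devices, while with
  # experimental_fetch_to_device=False they remain on the host until consumed,
  # e.g. by strategy.run.
  #
  #   options = tf.distribute.InputOptions(
  #       experimental_fetch_to_device=False,
  #       experimental_replication_mode=(
  #           tf.distribute.InputReplicationMode.PER_WORKER))
  #   ds = strategy.distribute_datasets_from_function(dataset_fn, options)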

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER)
          ],
          mode=["eager"],
      ))
  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
      self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      x = distribution.run(lambda inputs: inputs, args=(x,))
      assert x.values[
          0].device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          1].device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA)
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForInvalidCombinations(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    with self.assertRaises(ValueError):
      distribution.experimental_distribute_datasets_from_function(
          dataset_fn, input_options)

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_per_replica_buffer_size=2),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_per_replica_buffer_size=2),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testPrefetchBufferSizeInputOptions(self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    # validating the values
    x = next(iter(ds))
    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
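
  # A rough sketch (illustrative only; `strategy` and `dataset_fn` are assumed,
  # and `tf` is `import tensorflow as tf`) of the buffer-size knob exercised
  # above: experimental_per_replica_buffer_size controls how many elements the
  # per-replica prefetch buffer may hold (the default is 1).
  #
  #   options = tf.distribute.InputOptions(
  #       experimental_fetch_to_device=True,
  #       experimental_per_replica_buffer_size=2)
  #   ds = strategy.distribute_datasets_from_function(dataset_fn, options)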

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerWorkerInputOptions(self, distribution,
                                               input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    # validating the values
    x = next(iter(ds))
    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerReplicaInputOptions(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 10) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)
    expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    for i, x in enumerate(ds):
      # validating the values
      assert x.values[0].numpy() == expected[i]
      assert x.values[1].numpy() == expected[i] * 2
      loop_num = i
    assert loop_num == len(expected) - 1


class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
                                           parameterized.TestCase):
  """Tests for distributed iterators which read from tf.data service."""

  def setUp(self):
    super(DistributedIteratorTfDataServiceTest, self).setUp()
    self.num_workers = 3
    if combinations.in_main_process():
      self.dispatcher = server_lib.DispatchServer()
      self.workers = []
      for _ in range(self.num_workers):
        self.workers.append(
            server_lib.WorkerServer(
                server_lib.WorkerConfig(
                    dispatcher_address=self.dispatcher.target.split("://")[1],
                    heartbeat_interval_ms=100,
                    dispatcher_timeout_ms=1000)))
      combinations.env().tf_data_service_dispatcher = self.dispatcher.target
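
  # The tests below read from the in-process tf.data service cluster started in
  # setUp. A rough sketch of the plain tf.data usage pattern they build on
  # (illustrative only; `dispatcher.target` comes from a DispatchServer as
  # above, and `tf` is `import tensorflow as tf`):
  #
  #   ds = tf.data.Dataset.range(1, 50)
  #   ds = ds.apply(
  #       tf.data.experimental.service.distribute(
  #           processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
  #           service=dispatcher.target,
  #           job_name="foo"))
  #
  # With ShardingPolicy.OFF every tf.data service worker produces the full
  # dataset, which is why the tests expect `num_workers` copies of range(1, 50).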

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testTfDataService(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.Dataset.range(1, 50)
    dataset = dataset.apply(
        data_service_ops._distribute(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=combinations.env().tf_data_service_dispatcher,
            job_name="foo"))

    dist_dataset = input_util.get_distributed_dataset(dataset, input_workers,
                                                      distribution)
    iterator = iter(dist_dataset)
    results = []
    for element in iterator:
      local_results = distribution.experimental_local_results(element)
      for result in local_results:
        # input_lib.distributed_dataset may add extra '0' elements to pad
        # per-replica results.
        if result.numpy() != 0:
          results.append(result.numpy())
    self.assertNotEmpty(results)
    gathered = distribution.gather(constant_op.constant(results), axis=0)
    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)

    histogram_proto = (
        input_lib._distributed_dataset_initialization_time_milliseconds
        .get_cell(distribution.__class__.__name__, "1").value())
    self.assertGreater(histogram_proto.num, 0.0)
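
  # The next test registers the dataset with the tf.data service up front and
  # has every input pipeline read it back by id. A loose sketch of that
  # register/consume pattern (illustrative only; `dispatcher.target` is
  # assumed, and `tf` is `import tensorflow as tf`):
  #
  #   ds = tf.data.Dataset.range(1, 50)
  #   dataset_id = tf.data.experimental.service.register_dataset(
  #       service=dispatcher.target, dataset=ds)
  #   per_worker_ds = tf.data.experimental.service.from_dataset_id(
  #       processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
  #       service=dispatcher.target,
  #       dataset_id=dataset_id,
  #       element_spec=ds.element_spec,
  #       job_name="shared_job")
  #
  # Sharing a job_name means all consumers read from the same tf.data service
  # job rather than each starting an independent one.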

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testDistributeDatasetFromFunction(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=num_workers))

    dataset = dataset_ops.Dataset.range(1, 50)
    dataset_id = data_service_ops.register_dataset(
        service=combinations.env().tf_data_service_dispatcher,
        dataset=dataset)

    def dataset_fn(input_context):
      del input_context
      return data_service_ops.from_dataset_id(
          processing_mode=data_service_ops.ShardingPolicy.OFF,
          service=combinations.env().tf_data_service_dispatcher,
          dataset_id=dataset_id,
          element_spec=dataset.element_spec,
          job_name="shared_job")

    dist_dataset = input_util.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, distribution)

    iterator = iter(dist_dataset)
    results = []
    for element in iterator:
      local_results = distribution.experimental_local_results(element)
      for result in local_results:
        # input_lib.distributed_dataset may add extra '0' elements to pad
        # per-replica results.
        if result.numpy() != 0:
          results.append(result.numpy())
    self.assertNotEmpty(results)
    gathered = distribution.gather(constant_op.constant(results), axis=0)
    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)
    histogram_proto = (
        input_lib
        ._distributed_dataset_from_function_initialization_time_milliseconds
        .get_cell(distribution.__class__.__name__, "1").value())
    self.assertGreater(histogram_proto.num, 0.0)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testDistributeDatasetFromFunctionNested(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(
          distribute_lib.InputContext(
              num_input_pipelines=num_workers,
              input_pipeline_id=i,
              num_replicas_in_sync=num_workers))

    class InnerType(extension_type.ExtensionType):
      tensor: ops.Tensor

    class OuterType(extension_type.ExtensionType):
      inner: InnerType

    def dataset_fn(input_context):
      del input_context

      def data_fn(batch_id) -> OuterType:
        del batch_id

        return OuterType(
            inner=InnerType(tensor=constant_op.constant([[0., 1.], [2., 3.]])))

      return dataset_ops.Dataset.range(1, 10).map(data_fn)

    dist_dataset = input_util.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, distribution)

    iterator = iter(dist_dataset)
    results = []
    for element in iterator:
      local_results = distribution.experimental_local_results(element)
      for result in local_results:
        results.append(result)

    expect_component = OuterType(
        inner=InnerType(tensor=constant_op.constant([[0., 1.], [2., 3.]])))
    self.assertCountEqual(
        num_workers * [expect_component for _ in range(1, 10)], results)


if __name__ == "__main__":
  test_util.main()
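
# A minimal standalone sketch (illustrative only; `tf` is
# `import tensorflow as tf`) of the ExtensionType pattern used in
# testDistributeDatasetFromFunctionNested above: an ExtensionType is a
# user-defined composite value that tf.data and tf.distribute can carry through
# their pipelines alongside plain tensors.
#
#   class Inner(tf.experimental.ExtensionType):
#     tensor: tf.Tensor
#
#   class Outer(tf.experimental.ExtensionType):
#     inner: Inner
#
#   ds = tf.data.Dataset.range(4).map(
#       lambda _: Outer(inner=Inner(tensor=tf.constant([1.0, 2.0]))))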