1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the input_lib library."""
16
17import collections
18
19from absl.testing import parameterized
20import numpy as np
21
22from tensorflow.python import tf2
23from tensorflow.python.data.experimental.ops import data_service_ops
24from tensorflow.python.data.experimental.service import server_lib
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import options as options_lib
27from tensorflow.python.data.ops.options import AutoShardPolicy
28from tensorflow.python.distribute import combinations
29from tensorflow.python.distribute import device_util
30from tensorflow.python.distribute import distribute_lib
31from tensorflow.python.distribute import distribute_utils
32from tensorflow.python.distribute import input_lib
33from tensorflow.python.distribute import input_util
34from tensorflow.python.distribute import multi_worker_util
35from tensorflow.python.distribute import reduce_util
36from tensorflow.python.distribute import strategy_combinations
37from tensorflow.python.distribute import test_util
38from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
39from tensorflow.python.eager import context
40from tensorflow.python.eager import def_function
41from tensorflow.python.eager import test
42from tensorflow.python.framework import composite_tensor
43from tensorflow.python.framework import constant_op
44from tensorflow.python.framework import dtypes
45from tensorflow.python.framework import errors
46from tensorflow.python.framework import extension_type
47from tensorflow.python.framework import ops
48from tensorflow.python.framework import sparse_tensor
49from tensorflow.python.framework import test_util as framework_test_util
50from tensorflow.python.ops import array_ops
51from tensorflow.python.ops import control_flow_ops
52from tensorflow.python.ops import math_ops
53from tensorflow.python.ops import sparse_ops
54from tensorflow.python.ops import variables
55from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
56from tensorflow.python.util import nest
57
58
59class DistributedIteratorTestBase(test.TestCase):
60
61  # The passed `input_context` is used to create a sharded dataset in the
62  # between-graph case.
63  # TODO(yuefengz): rewrite the following method to make it less repetitive.
64  def _wrap_iterator(self,
65                     input_type,
66                     dataset_or_input_fn,
67                     input_workers,
68                     devices,
69                     num_replicas_in_sync,
70                     strategy,
71                     input_context=None):
72    # The `input_context` passed in is used to shard the dataset for
73    # MultiWorkerMirroredStrategy. It doesn't apply to the in-graph case, where
74    # multiple InputContexts are needed.
75    if input_type == "input_fn":
76      self.assertIsNone(
77          input_context,
78          msg=("The `input_context` arg is only used to shard the dataset in "
79               "`MultiWorkerMirroredStrategy` when the input type is dataset."))
80
81      input_contexts = []
82      for i in range(input_workers.num_workers):
83        input_contexts.append(
84            distribute_lib.InputContext(
85                # Note: `input_workers.num_workers` is always 1 in the
86                # between-graph case.
87                num_input_pipelines=input_workers.num_workers,
88                input_pipeline_id=i,
89                num_replicas_in_sync=len(devices)))
90
91      iterator = input_lib_v1.InputFunctionIterator(dataset_or_input_fn,
92                                                    input_workers,
93                                                    input_contexts, strategy)
94    else:
95      iterator = input_lib_v1.DatasetIterator(
96          dataset_or_input_fn,
97          input_workers,
98          strategy,
99          num_replicas_in_sync=num_replicas_in_sync,
100          input_context=input_context)
101    return iterator
102
103  def _wrap_dataset(self,
104                    input_type,
105                    dataset,
106                    input_workers,
107                    num_replicas_in_sync,
108                    strategy,
109                    input_context=None):
110    if input_type == "dataset":
111      if tf2.enabled():
112        return input_lib.DistributedDataset(
113            input_workers,
114            strategy,
115            dataset,
116            num_replicas_in_sync=num_replicas_in_sync,
117            input_context=input_context)
118      else:
119        return input_lib_v1.DistributedDatasetV1(
120            dataset,
121            input_workers,
122            strategy,
123            num_replicas_in_sync=num_replicas_in_sync,
124            input_context=input_context)
125    else:
126      return strategy.distribute_datasets_from_function(dataset)
127
128  def _assert_iterator_values(self,
129                              iterator,
130                              expected_values,
131                              evaluate_fn,
132                              devices,
133                              enable_get_next_as_optional=False):
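    # Pulls len(expected_values) steps from `iterator`, evaluating the
    # per-replica outputs for every device, and compares them element-wise
    # against `expected_values`.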
134    actual_values = []
135    for _ in range(len(expected_values)):
136      if enable_get_next_as_optional:
137        next_element = iterator.get_next_as_optional().get_value()
138      else:
139        next_element = iterator.get_next()
140      computed_value = evaluate_fn([
141          distribute_utils.select_replica(r, next_element)
142          for r in range(len(devices))
143      ])
144      actual_values.append(computed_value)
145    for expected_value, actual_value in zip(expected_values, actual_values):
146      for expected, actual in zip(expected_value, actual_value):
147        self.assertAllEqual(expected, actual)
148
149  def _assert_dataset_values_for_loop(self, dataset, expected_values,
150                                      evaluate_fn, devices):
151    actual_values = []
152    for x in dataset:
153      computed_value = self.evaluate(
154          [distribute_utils.select_replica(r, x) for r in range(len(devices))])
155      actual_values.append(computed_value)
156    for expected_value, actual_value in zip(expected_values, actual_values):
157      for expected, actual in zip(expected_value, actual_value):
158        self.assertAllEqual(expected, actual)
159
160  def _test_input_iteration(self,
161                            input_type,
162                            api_type,
163                            iteration_type,
164                            dataset_or_input_fn,
165                            worker_device_pairs,
166                            expected_values,
167                            strategy,
168                            sess=None,
169                            num_replicas_in_sync=None,
170                            input_context=None):
171    if iteration_type == "for_loop" and not context.executing_eagerly():
172      self.skipTest("unsupported test combination.")
173
174    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
175      self.skipTest("unsupported test combination.")
176
177    if api_type == "wrap_into_iterator" and input_type == "input_fn":
178      self.skipTest("unsupported test combination.")
179
180    devices = nest.flatten([ds for _, ds in worker_device_pairs])
181    input_workers = input_lib.InputWorkers(worker_device_pairs)
182
183    if api_type == "wrap_into_iterator":
184      iterator = self._wrap_iterator(
185          input_type,
186          dataset_or_input_fn,
187          input_workers,
188          devices,
189          num_replicas_in_sync,
190          strategy,
191          input_context=input_context)
192    else:
193      # wrapping into a dataset:
194      dataset = self._wrap_dataset(
195          input_type,
196          dataset_or_input_fn,
197          input_workers,
198          num_replicas_in_sync,
199          strategy,
200          input_context=input_context)
201
202      if ops.executing_eagerly_outside_functions():
203        iterator = iter(dataset)
204      else:
205        if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
206          iterator = dataset.make_initializable_iterator()
207        else:
208          self.skipTest("unsupported test combination")
209
210    if isinstance(iterator, composite_tensor.CompositeTensor):
211      nest.assert_same_structure(
212          iterator, iterator._type_spec, expand_composites=True)
213
214    if iteration_type == "get_next":
215      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
216      if not ops.executing_eagerly_outside_functions():
217        evaluate(control_flow_ops.group(iterator.initializer))
218
219      def test_get_next(iterator):
220        self._assert_iterator_values(iterator, expected_values, evaluate,
221                                     devices)
222
223        with self.assertRaises(errors.OutOfRangeError):
224          self._assert_iterator_values(iterator, expected_values, evaluate,
225                                       devices)
226
227        # After re-initialization, we should be able to iterate again.
228        if not ops.executing_eagerly_outside_functions():
229          evaluate(control_flow_ops.group(iterator.initializer))
230        else:
231          if api_type == "wrap_into_iterator":
232            self.skipTest("unsupported test combination")
233          else:
234            iterator = iter(dataset)
235
236        self._assert_iterator_values(iterator, expected_values, evaluate,
237                                     devices)
238
239      def test_get_next_as_optional(iterator):
240        self._assert_iterator_values(
241            iterator,
242            expected_values,
243            evaluate,
244            devices,
245            enable_get_next_as_optional=True)
246
247        next_element = iterator.get_next_as_optional()
248        self.assertFalse(self.evaluate(next_element.has_value()))
249        with self.assertRaises(errors.InvalidArgumentError):
250          self._assert_iterator_values(
251              iterator, [0],
252              evaluate,
253              devices,
254              enable_get_next_as_optional=True)
255
256      test_get_next(iterator)
257
258      # Re-create the iterator before testing get_next_as_optional.
259      if not tf2.enabled():
260        # TODO(yuefengz): we should split this function.
261        return
262      else:
263        if api_type == "wrap_into_iterator":
264          return
265        else:
266          iterator = iter(dataset)
267
268      test_get_next_as_optional(iterator)
269
270    if iteration_type == "for_loop" and context.executing_eagerly():
271      self._assert_dataset_values_for_loop(dataset, expected_values,
272                                           self.evaluate, devices)
273
274  def _create_dataset_or_input_fn(self, input_type, input_fn):
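    # For the "input_fn" input type the callable itself is returned (it will be
    # invoked later, once per input pipeline); otherwise it is called once with
    # a default InputContext to build a single dataset.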
275    if input_type == "input_fn":
276      return input_fn
277    else:
278      return input_fn(distribute_lib.InputContext())
279
280
281class DistributedIteratorTest(DistributedIteratorTestBase,
282                              parameterized.TestCase):
283
284  @combinations.generate(
285      combinations.combine(
286          mode=["eager"],
287          distribution=[
288              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
289          ]))
290  def testMultiDeviceIterInitialize(self, distribution):
291    if tf2.enabled():
292      self.skipTest("Only V1 is supported.")
293    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
294                                              "/device:CPU:0"])]
295    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
296
297    input_workers = input_lib.InputWorkers(worker_device_pairs)
298
299    dist_dataset = input_util.get_distributed_dataset(
300        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)
301
302    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)
303
304    @def_function.function
305    def init_func_for_iter():
306      self.evaluate(iterator.initializer)
307
308    init_func_for_iter()
309
310  @combinations.generate(
311      combinations.combine(
312          mode=["graph", "eager"],
313          input_type=["input_fn", "dataset"],
314          api_type=["wrap_into_iterator", "wrap_into_dataset"],
315          iteration_type=["get_next", "for_loop"],
316          distribution=[
317              strategy_combinations.one_device_strategy,
318              strategy_combinations.mirrored_strategy_with_one_cpu,
319          ],
320          enable_get_next_as_optional=[True, False]))
321  def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
322                       enable_get_next_as_optional):
323    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
324    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
325    dataset_or_input_fn = self._create_dataset_or_input_fn(
326        input_type, dataset_fn)
327
328    expected_values = [[i] for i in range(10)]
329
330    distribution.extended.experimental_enable_get_next_as_optional = (
331        enable_get_next_as_optional)
332    self._test_input_iteration(input_type, api_type, iteration_type,
333                               dataset_or_input_fn, worker_device_pairs,
334                               expected_values, distribution)
335
336  @combinations.generate(
337      combinations.combine(
338          mode=["eager"],
339          input_type=["input_fn", "dataset"],
340          api_type=["wrap_into_dataset"],
341          iteration_type=["get_next", "for_loop"],
342          distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu],
343          enable_get_next_as_optional=[True, False]))
344  def testOneDeviceCPUMultiWorker(self, input_type, api_type, iteration_type,
345                                  distribution, enable_get_next_as_optional):
346    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
347    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
348    dataset_or_input_fn = self._create_dataset_or_input_fn(
349        input_type, dataset_fn)
350
351    expected_values = [[i] for i in range(10)]
352
353    distribution.extended.experimental_enable_get_next_as_optional = (
354        enable_get_next_as_optional)
355    self._test_input_iteration(input_type, api_type, iteration_type,
356                               dataset_or_input_fn, worker_device_pairs,
357                               expected_values, distribution)
358
359  @combinations.generate(
360      combinations.combine(
361          mode=["graph", "eager"],
362          input_type=["input_fn", "dataset"],
363          api_type=["wrap_into_iterator", "wrap_into_dataset"],
364          iteration_type=["get_next", "for_loop"],
365          distribution=[
366              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
367              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
368          ],
369          enable_get_next_as_optional=[True, False]))
370  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
371                                 distribution, enable_get_next_as_optional):
372    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
373                                              "/device:CPU:0"])]
374    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
375    dataset_or_input_fn = self._create_dataset_or_input_fn(
376        input_type, dataset_fn)
377
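    # With two replicas, each step consumes two consecutive elements of the
    # range, one per device.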
378    expected_values = [[i, i + 1] for i in range(0, 10, 2)]
379
380    distribution.extended.experimental_enable_get_next_as_optional = (
381        enable_get_next_as_optional)
382    self._test_input_iteration(input_type, api_type, iteration_type,
383                               dataset_or_input_fn, worker_device_pairs,
384                               expected_values, distribution)
385
386  @combinations.generate(
387      combinations.combine(
388          mode=["graph", "eager"],
389          input_type=["input_fn", "dataset"],
390          api_type=["wrap_into_iterator", "wrap_into_dataset"],
391          iteration_type=["get_next", "for_loop"],
392          distribution=[strategy_combinations.tpu_strategy],
393          enable_get_next_as_optional=[True, False]))
394  def testTPU(self, input_type, api_type, iteration_type, distribution,
395              enable_get_next_as_optional):
396    worker_device_pairs = collections.OrderedDict()
397    for tpu_device in distribution.extended.worker_devices:
398      host_device = device_util.get_host_for_device(tpu_device)
399      worker_device_pairs.setdefault(host_device, [])
400      worker_device_pairs[host_device].append(tpu_device)
401    worker_device_pairs = worker_device_pairs.items()
402    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
403    dataset_or_input_fn = self._create_dataset_or_input_fn(
404        input_type, dataset_fn)
405
406    expected_values = [[i, i + 1] for i in range(0, 10, 2)]
407
408    distribution.extended.experimental_enable_get_next_as_optional = (
409        enable_get_next_as_optional)
410    self._test_input_iteration(input_type, api_type, iteration_type,
411                               dataset_or_input_fn, worker_device_pairs,
412                               expected_values, distribution)
413
414  @combinations.generate(
415      combinations.combine(
416          mode=["graph", "eager"],
417          input_type=["input_fn", "dataset"],
418          api_type=["wrap_into_iterator", "wrap_into_dataset"],
419          iteration_type=["get_next", "for_loop"],
420          distribution=[
421              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
422              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
423          ],
424          enable_get_next_as_optional=[True, False]))
425  def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
426                       enable_get_next_as_optional):
427    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
428                                              "/device:CPU:0"])]
429
430    def dataset_fn(ctx):
431      del ctx
432      dataset1 = dataset_ops.Dataset.range(10)
433      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
434      return dataset_ops.Dataset.zip((dataset1, dataset2))
435
436    dataset_or_input_fn = self._create_dataset_or_input_fn(
437        input_type, dataset_fn)
438
439    expected_values = [
440        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
441    ]
442
443    distribution.extended.experimental_enable_get_next_as_optional = (
444        enable_get_next_as_optional)
445    self._test_input_iteration(input_type, api_type, iteration_type,
446                               dataset_or_input_fn, worker_device_pairs,
447                               expected_values, distribution)
448
449  @combinations.generate(
450      combinations.combine(
451          mode=["eager"],
452          input_type=["input_fn", "dataset"],
453          api_type=["wrap_into_dataset"],
454          iteration_type=["get_next", "for_loop"],
455          distribution=[
456              strategy_combinations.multi_worker_mirrored_2x2_gpu,
457              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
458          ],
459          enable_get_next_as_optional=[True, False]))
460  def testTupleDatasetMultiworker(self, input_type, api_type, iteration_type,
461                                  distribution, enable_get_next_as_optional):
462    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
463                                              "/device:GPU:1"])]
464
465    def dataset_fn(ctx):
466      del ctx
467      dataset1 = dataset_ops.Dataset.range(10)
468      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
469      return dataset_ops.Dataset.zip((dataset1, dataset2))
470
471    dataset_or_input_fn = self._create_dataset_or_input_fn(
472        input_type, dataset_fn)
473
474    expected_values = [
475        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
476    ]
477
478    distribution.extended.experimental_enable_get_next_as_optional = (
479        enable_get_next_as_optional)
480
481    # No input_context is passed in, so the dataset is not sharded.
482    self._test_input_iteration(input_type, api_type, iteration_type,
483                               dataset_or_input_fn, worker_device_pairs,
484                               expected_values, distribution)
485
486  @combinations.generate(
487      combinations.combine(
488          mode=["eager"],
489          distribution=[
490              strategy_combinations.one_device_strategy,
491              strategy_combinations.mirrored_strategy_with_one_cpu,
492              strategy_combinations.multi_worker_mirrored_2x1_cpu,
493          ]))
494  def testIterableIterator(self, distribution):
495    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
496    input_workers = input_lib.InputWorkers(worker_device_pairs)
497
498    dataset = dataset_ops.Dataset.range(10)
499    dist_dataset = input_util.get_distributed_dataset(dataset, input_workers,
500                                                      distribution)
501
502    iterator = iter(dist_dataset)
503    for i, element in enumerate(iterator):
504      self.assertAllEqual(distribution.experimental_local_results(element), [i])
505
506  @combinations.generate(
507      combinations.combine(
508          mode=["eager"],
509          distribution=[
510              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
511              strategy_combinations.mirrored_strategy_with_one_cpu,
512          ]))
513  def testIterableIteratorError(self, distribution):
514    dataset = dataset_ops.Dataset.range(10).batch(2)
515    dist_dataset = distribution.experimental_distribute_dataset(dataset)
516
517    iterator = iter(dist_dataset)
518    # Raises an error when next(iterator) is called without a strategy scope.
519    with self.assertRaises(ValueError):
520
521      def replica_fn1(iterator):
522        return next(iterator)
523
524      distribution.run(replica_fn1, args=(iterator,))
525
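    # With one replica each global batch of two elements stays on that replica;
    # with two replicas each batch is split into one element per replica.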
526    if distribution.num_replicas_in_sync == 1:
527      expected_result = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8, 9]]]
528    elif distribution.num_replicas_in_sync == 2:
529      expected_result = [[[0], [1]], [[2], [3]], [[4], [5]], [[6], [7]],
530                         [[8], [9]]]
531
532    with distribution.scope():
533
534      def replica_fn2(iterator):
535        return iterator
536
537      result = distribution.run(replica_fn2, args=(next(iterator),))
538      self.assertAllEqual(
539          distribution.experimental_local_results(result), expected_result[0])
540
541    # Confirm default ReplicaContext also works
542    iterator = iter(dist_dataset)
543    for i, element in enumerate(iterator):
544      self.assertAllEqual(
545          distribution.experimental_local_results(element), expected_result[i])
546
547  @combinations.generate(
548      combinations.combine(
549          mode=["graph", "eager"],
550          input_type=["input_fn", "dataset"],
551          api_type=["wrap_into_iterator", "wrap_into_dataset"],
552          iteration_type=["get_next", "for_loop"],
553          drop_remainder=[True, False],
554          distribution=[
555              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
556              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
557          ]))
558  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
559                               drop_remainder, distribution):
560    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
561                                              "/device:CPU:0"])]
562    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(  # pylint: disable=g-long-lambda
563        2, drop_remainder=drop_remainder)
564    dataset_or_input_fn = self._create_dataset_or_input_fn(
565        input_type, dataset_fn)
566
567    # The last global batch only contains data for one replica.
568    if drop_remainder:
569      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
570    else:
571      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
572    distribution.extended.experimental_enable_get_next_as_optional = True
573    self._test_input_iteration(input_type, api_type, iteration_type,
574                               dataset_or_input_fn, worker_device_pairs,
575                               expected_values, distribution)
576
577  @combinations.generate(
578      combinations.combine(
579          mode=["eager"],
580          input_type=["input_fn", "dataset"],
581          api_type=["wrap_into_dataset"],
582          iteration_type=["get_next", "for_loop"],
583          drop_remainder=[True, False],
584          distribution=[
585              strategy_combinations.multi_worker_mirrored_2x1_cpu,
586              strategy_combinations.multi_worker_mirrored_2x1_gpu,
587          ]))
588  def testUnevenDatasetBatchesMultiWorker(self, input_type, api_type,
589                                          iteration_type, drop_remainder,
590                                          distribution):
591    # Actual devices don't matter in this test as long as the number of global
592    # replicas is 2.
593    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
594    cr = distribution.cluster_resolver
595    self.assertIsNotNone(cr)
596    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
597                                                  cr.task_type)
598    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
599                                                    cr.task_type, cr.task_id)
600
601    def dataset_fn(_):
602      dataset = dataset_ops.Dataset.range(9)
603
604      if input_type == "input_fn":
605        # When input_fn is used, there is no automatic rebatching and sharding,
606        # so we add them here.
607        return dataset.shard(worker_count, id_in_cluster).batch(1)
608      else:
609        return dataset.batch(2, drop_remainder=drop_remainder)
610
611    dataset_or_input_fn = self._create_dataset_or_input_fn(
612        input_type, dataset_fn)
613
614    if drop_remainder and input_type == "dataset":
615      if id_in_cluster == 0:
616        expected_values = [[[0]], [[2]], [[4]], [[6]]]
617      else:
618        expected_values = [[[1]], [[3]], [[5]], [[7]]]
619    else:
620      # The last global batch only contains data for one replica.
621      if id_in_cluster == 0:
622        expected_values = [[[0]], [[2]], [[4]], [[6]], [[8]]]
623      else:
624        expected_values = [[[1]], [[3]], [[5]], [[7]], [[]]]
625    distribution.extended.experimental_enable_get_next_as_optional = True
626    self._test_input_iteration(
627        input_type,
628        api_type,
629        iteration_type,
630        dataset_or_input_fn,
631        worker_device_pairs,
632        expected_values,
633        distribution,
634        num_replicas_in_sync=distribution.num_replicas_in_sync,
635        input_context=distribution.extended._make_input_context())
636
637  @combinations.generate(
638      combinations.combine(
639          mode=["eager"],
640          input_type=["input_fn", "dataset"],
641          api_type=["wrap_into_dataset"],
642          iteration_type=["get_next", "for_loop"],
643          drop_remainder=[True, False],
644          distribution=[
645              strategy_combinations.multi_worker_mirrored_2x2_gpu,
646              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
647          ]))
648  def testUnevenDatasetBatchesMultiWorkerFourReplicas(self, input_type,
649                                                      api_type, iteration_type,
650                                                      drop_remainder,
651                                                      distribution):
652    # Actual devices don't matter in this test as long as the number of global
653    # replicas is 4.
654    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
655                                              "/device:GPU:1"])]
656    cr = distribution.cluster_resolver
657    self.assertIsNotNone(cr)
658    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
659                                                  cr.task_type)
660    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
661                                                    cr.task_type, cr.task_id)
662
663    def dataset_fn(_):
664      dataset = dataset_ops.Dataset.range(15)
665
666      if input_type == "input_fn":
667        # When input_fn is used, there is no automatic rebatching and sharding,
668        # so we add them here.
669        return dataset.shard(worker_count, id_in_cluster).batch(1)
670      else:
671        return dataset.batch(4, drop_remainder=drop_remainder)
672
673    dataset_or_input_fn = self._create_dataset_or_input_fn(
674        input_type, dataset_fn)
675
676    # The last global batch only contains data for one replica.
677    if drop_remainder and input_type == "dataset":
678      if id_in_cluster == 0:
679        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]]]
680      else:
681        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]]]
682    else:
683      if id_in_cluster == 0:
684        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]], [[12], [14]]]
685      else:
686        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]], [[13], []]]
687    distribution.extended.experimental_enable_get_next_as_optional = True
688    self._test_input_iteration(
689        input_type,
690        api_type,
691        iteration_type,
692        dataset_or_input_fn,
693        worker_device_pairs,
694        expected_values,
695        distribution,
696        num_replicas_in_sync=distribution.num_replicas_in_sync,
697        input_context=distribution.extended._make_input_context())
698
699  @combinations.generate(
700      combinations.combine(
701          mode=["graph", "eager"],
702          input_type=["dataset"],
703          api_type=["wrap_into_iterator", "wrap_into_dataset"],
704          iteration_type=["get_next", "for_loop"],
705          num_replicas_in_sync=[None, 2],
706          distribution=[
707              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
708              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
709          ],
710          enable_get_next_as_optional=[True, False]))
711  def testBatchSplitting(self, input_type, api_type, iteration_type,
712                         num_replicas_in_sync, distribution,
713                         enable_get_next_as_optional):
714    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
715                                              "/device:CPU:0"])]
716    batch_size = 10
717    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
718    dataset_or_input_fn = self._create_dataset_or_input_fn(
719        input_type, dataset_fn)
720
721    updated_batch_size = (
722        batch_size //
723        num_replicas_in_sync if num_replicas_in_sync else batch_size)
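    # With num_replicas_in_sync=2 each global batch of 10 is split into two
    # per-replica batches of 5; with None, every replica receives a separate
    # full batch of 10 per step.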
724    expected_values = [[
725        range(i, i + updated_batch_size),
726        range(i + updated_batch_size, i + 2 * updated_batch_size)
727    ] for i in range(0, 100, updated_batch_size * 2)]
728
729    distribution.extended.experimental_enable_get_next_as_optional = (
730        enable_get_next_as_optional)
731    self._test_input_iteration(
732        input_type,
733        api_type,
734        iteration_type,
735        dataset_or_input_fn,
736        worker_device_pairs,
737        expected_values,
738        distribution,
739        sess=None,
740        num_replicas_in_sync=num_replicas_in_sync)
741
742  @combinations.generate(
743      combinations.combine(
744          mode=["eager"],
745          input_type=["dataset"],
746          api_type=["wrap_into_dataset"],
747          iteration_type=["get_next", "for_loop"],
748          num_replicas_in_sync=[None, 2],
749          distribution=[
750              strategy_combinations.multi_worker_mirrored_2x2_gpu,
751              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
752          ],
753          enable_get_next_as_optional=[True, False]))
754  def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
755                                    num_replicas_in_sync, distribution,
756                                    enable_get_next_as_optional):
757    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
758                                              "/device:GPU:1"])]
759    batch_size = 10
760    cr = distribution.cluster_resolver
761    self.assertIsNotNone(cr)
762
763    def dataset_fn(_):
764      dataset = dataset_ops.Dataset.range(100).batch(batch_size)
765      return dataset
766
767    dataset_or_input_fn = self._create_dataset_or_input_fn(
768        input_type, dataset_fn)
769
770    updated_batch_size = (
771        batch_size //
772        num_replicas_in_sync if num_replicas_in_sync else batch_size)
773    expected_values = [
774        [  # pylint: disable=g-complex-comprehension
775            range(i, i + updated_batch_size),
776            range(i + updated_batch_size, i + 2 * updated_batch_size)
777        ] for i in range(0, 100, updated_batch_size * 2)
778    ]
779
780    distribution.extended.experimental_enable_get_next_as_optional = (
781        enable_get_next_as_optional)
782    self._test_input_iteration(
783        input_type,
784        api_type,
785        iteration_type,
786        dataset_or_input_fn,
787        worker_device_pairs,
788        expected_values,
789        distribution,
790        sess=None,
791        num_replicas_in_sync=num_replicas_in_sync)
792
793  @combinations.generate(
794      combinations.combine(
795          mode=["eager"],
796          distribution=[
797              strategy_combinations.one_device_strategy,
798              strategy_combinations.mirrored_strategy_with_one_cpu,
799              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
800              strategy_combinations.tpu_strategy,
801              strategy_combinations.central_storage_strategy_with_two_gpus,
802              strategy_combinations.multi_worker_mirrored_2x2_gpu,
803              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
804              strategy_combinations.multi_worker_mirrored_2x1_cpu,
805          ],
806      ))
807  def testCacheAcrossIteration(self, distribution):
808    if not tf2.enabled():
809      self.skipTest("Only V2 is supported.")
810
811    dataset = dataset_ops.Dataset.range(16).shuffle(16).cache().batch(4)
812    dist_dataset = distribution.experimental_distribute_dataset(dataset)
813
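    # Because the shuffled dataset is cached, both passes below should yield
    # identical batches.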
814    first_epoch = list(
815        distribution.experimental_local_results(x) for x in dist_dataset)
816    second_epoch = list(
817        distribution.experimental_local_results(x) for x in dist_dataset)
818
819    self.assertAllEqual(first_epoch, second_epoch)
820
821  @combinations.generate(
822      combinations.combine(
823          mode=["eager"],
824          distribution=[
825              strategy_combinations.one_device_strategy,
826              strategy_combinations.mirrored_strategy_with_one_cpu,
827              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
828              strategy_combinations.tpu_strategy,
829              strategy_combinations.central_storage_strategy_with_two_gpus,
830              strategy_combinations.multi_worker_mirrored_2x2_gpu,
831              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
832              strategy_combinations.multi_worker_mirrored_2x1_cpu,
833          ],
834          reshuffle=[True, False]))
835  def testShuffleAcrossIterations(self, distribution, reshuffle):
836    if not tf2.enabled():
837      self.skipTest("Only V2 is supported.")
838
839    dataset = dataset_ops.Dataset.range(12).shuffle(
840        12, reshuffle_each_iteration=reshuffle).batch(4)
841    dist_dataset = distribution.experimental_distribute_dataset(dataset)
842
843    first_epoch = list(
844        distribution.experimental_local_results(x) for x in dist_dataset)
845    second_epoch = list(
846        distribution.experimental_local_results(x) for x in dist_dataset)
847
848    if reshuffle:
849      self.assertNotAllEqual(first_epoch, second_epoch)
850    else:
851      self.assertAllEqual(first_epoch, second_epoch)
852
853  @combinations.generate(
854      combinations.combine(
855          mode=["eager"],
856          distribution=[
857              strategy_combinations.one_device_strategy,
858              strategy_combinations.mirrored_strategy_with_one_cpu,
859              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
860              strategy_combinations.tpu_strategy,
861              strategy_combinations.central_storage_strategy_with_two_gpus,
862              strategy_combinations.multi_worker_mirrored_2x2_gpu,
863              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
864              strategy_combinations.multi_worker_mirrored_2x1_cpu,
865          ]))
866  def testGetNextOptionalShapeFinite(self, distribution):
867    batch_size = 8
868    dataset = dataset_ops.DatasetV2.from_tensor_slices({
869        "feature": array_ops.ones([batch_size, 10]),
870        "label": array_ops.ones([batch_size]),
871    })
872    dataset = dataset.batch(batch_size, drop_remainder=True)
873    dist_dataset = distribution.experimental_distribute_dataset(dataset)
874
875    @def_function.function
876    def train_fn():
877      for data in dist_dataset:
878        data = nest.map_structure(distribution.experimental_local_results, data)
879        feature = data["feature"]
880        label = data["label"]
881
882        # Assert the shapes are still static from all replicas.
883        for replica_id in range(len(distribution.extended.worker_devices)):
884          self.assertEqual([None, 10],
885                           feature[replica_id].shape.as_list())
886          self.assertEqual([None], label[replica_id].shape.as_list())
887
888    train_fn()
889
890  @combinations.generate(
891      combinations.combine(
892          mode=["eager"],
893          distribution=[
894              strategy_combinations.one_device_strategy,
895              strategy_combinations.mirrored_strategy_with_one_cpu,
896              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
897              strategy_combinations.tpu_strategy,
898              strategy_combinations.central_storage_strategy_with_two_gpus,
899              strategy_combinations.multi_worker_mirrored_2x2_gpu,
900              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
901              strategy_combinations.multi_worker_mirrored_2x1_cpu,
902          ]))
903  def testGetNextOptionalShapeInfinite(self, distribution):
904    batch_size = 8
905    dataset = dataset_ops.DatasetV2.from_tensor_slices({
906        "feature": array_ops.ones([batch_size, 10]),
907        "label": array_ops.ones([batch_size]),
908    })
909    dataset = dataset.batch(batch_size, drop_remainder=True)
910    dataset = dataset.repeat()
911    dist_dataset = distribution.experimental_distribute_dataset(dataset)
912    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync
913
914    @def_function.function
915    def train_fn():
916      data = iter(dist_dataset).get_next_as_optional().get_value()
917      data = nest.map_structure(distribution.experimental_local_results, data)
918      feature = data["feature"]
919      label = data["label"]
920
921      # Assert the shapes are still static from all replicas.
922      for replica_id in range(len(distribution.extended.worker_devices)):
923        self.assertEqual([per_replica_batch_size, 10],
924                         feature[replica_id].shape.as_list())
925        self.assertEqual([per_replica_batch_size],
926                         label[replica_id].shape.as_list())
927
928    train_fn()
929
930  @combinations.generate(
931      combinations.combine(
932          mode=["eager"],
933          distribution=[
934              strategy_combinations.one_device_strategy,
935              strategy_combinations.mirrored_strategy_with_one_cpu,
936              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
937              strategy_combinations.tpu_strategy,
938              strategy_combinations.central_storage_strategy_with_two_gpus,
939              strategy_combinations.multi_worker_mirrored_2x2_gpu,
940              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
941              strategy_combinations.multi_worker_mirrored_2x1_cpu,
942          ]))
943  def testGetNextOptionalShapeEmpty(self, distribution):
944    batch_size = 8
945    dataset = dataset_ops.DatasetV2.from_tensor_slices({
946        "feature": array_ops.ones([batch_size, 10]),
947        "label": array_ops.ones([batch_size]),
948    })
949    dataset = dataset.batch(batch_size, drop_remainder=True)
950    dataset = dataset.repeat()
951    dist_dataset = distribution.experimental_distribute_dataset(dataset)
952    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync
953
954    @def_function.function
955    def train_fn():
956      data = iter(dist_dataset).get_next_as_optional()
957      feature_specs = data.element_spec["feature"]._component_specs
958      value_specs = data.element_spec["label"]._component_specs
959      if not isinstance(feature_specs, tuple):
960        feature_specs = (feature_specs,)
961        value_specs = (value_specs,)
962      # Assert the shapes are still static from all replicas.
963      for replica_id in range(len(distribution.extended.worker_devices)):
964        self.assertEqual([per_replica_batch_size, 10],
965                         feature_specs[replica_id].shape.as_list())
966        self.assertEqual([per_replica_batch_size],
967                         value_specs[replica_id].shape.as_list())
968
969    train_fn()
970
971  @combinations.generate(
972      combinations.combine(
973          mode=["eager"],
974          distribution=[
975              strategy_combinations.multi_worker_mirrored_2x1_cpu,
976          ],
977          input_type=["dataset"],
978          api_type=["wrap_into_iterator", "wrap_into_dataset"],
979          iteration_type=["get_next", "for_loop"],
980          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
981  def testAutoshardingOption(self, distribution, input_type, api_type,
982                             iteration_type, auto_shard_policy):
983    cr = distribution.cluster_resolver
984    self.assertIsNotNone(cr)
985    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
986                                                    cr.task_type, cr.task_id)
987    ds_option = options_lib.Options()
988    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
989    dataset_fn = (
990        lambda _: dataset_ops.Dataset.range(4).with_options(ds_option))
991    dataset_or_input_fn = self._create_dataset_or_input_fn(
992        input_type, dataset_fn)
993
994    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
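    # With AUTO sharding the two workers split range(4) between them; with OFF
    # every worker sees the full dataset.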
995    if auto_shard_policy == AutoShardPolicy.AUTO:
996      if id_in_cluster == 0:
997        expected_values = [[0], [2]]
998      else:
999        expected_values = [[1], [3]]
1000    else:
1001      expected_values = [[0], [1], [2], [3]]
1002    self._test_input_iteration(
1003        input_type,
1004        api_type,
1005        iteration_type,
1006        dataset_or_input_fn,
1007        worker_device_pairs,
1008        expected_values,
1009        distribution,
1010        input_context=distribution.extended._make_input_context())
1011
1012  @combinations.generate(
1013      combinations.combine(
1014          mode=["eager"],
1015          distribution=[
1016              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1017          ],
1018          input_type=["input_fn"],
1019          api_type=["wrap_into_dataset"],
1020          iteration_type=["get_next", "for_loop"]))
1021  def testDifferentDatasetsMultiWorker(self, distribution, input_type, api_type,
1022                                       iteration_type):
1023    cr = distribution.cluster_resolver
1024    self.assertIsNotNone(cr)
1025    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
1026                                                    cr.task_type, cr.task_id)
1027
1028    def dataset_fn(ctx):
1029      if ctx.input_pipeline_id == 0:
1030        return dataset_ops.Dataset.range(8).batch(2)
1031      else:
1032        return dataset_ops.Dataset.range(9).batch(2)
1033
1034    dataset_or_input_fn = self._create_dataset_or_input_fn(
1035        input_type, dataset_fn)
1036
1037    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1038
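    # Worker 0's dataset has 8 elements and worker 1's has 9, so at the fifth
    # step worker 0 produces an empty batch while worker 1 still yields [8].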
1039    if id_in_cluster == 0:
1040      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[]]]
1041    else:
1042      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8]]]
1043    distribution.extended.experimental_enable_get_next_as_optional = True
1044    self._test_input_iteration(input_type, api_type, iteration_type,
1045                               dataset_or_input_fn, worker_device_pairs,
1046                               expected_values, distribution)
1047
1048  @combinations.generate(
1049      combinations.combine(
1050          strategy=[
1051              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1052              strategy_combinations.multi_worker_mirrored_2x1_gpu,
1053          ],
1054          mode=["eager"]))
1055  def testLoopOverDatasetInTFFunction(self, strategy):
1056    dataset = dataset_ops.Dataset.range(10).map(lambda x: {  # pylint: disable=g-long-lambda
1057        "y": math_ops.cast(x, dtypes.float32) ** 2,
1058    }).batch(4)
1059    dist_dataset = strategy.experimental_distribute_dataset(dataset)
1060
1061    with strategy.scope():
1062      v = variables.Variable(0.0, aggregation=variables.VariableAggregation.SUM)
1063
1064    @def_function.function
1065    def iterator_fn(dist_dataset):
1066
1067      def assign_add_fn(data):
1068        v.assign_add(math_ops.reduce_sum(data["y"]))
1069
1070      for data in dist_dataset:
1071        strategy.run(assign_add_fn, args=(data,))
1072
1073    iterator_fn(dist_dataset)
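    # v accumulates the sum of x**2 over range(10), which is 285.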
1074    self.assertEqual(v.numpy(), 285.0)
1075
1076
1077class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
1078                                        parameterized.TestCase):
1079  """Tests for DistributedDataset with non-dense tensors."""
1080
1081  @combinations.generate(
1082      combinations.combine(
1083          mode=["eager"],
1084          distribution=[
1085              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1086              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
1087              strategy_combinations.multi_worker_mirrored_2x2_gpu,
1088          ],
1089          input_type=["dataset", "input_fn"],
1090          drop_remainder=[False, True],
1091          defun_type=["lambda", "tf_function"],
1092      ))
1093  def testRaggedSparse(self, distribution, input_type, drop_remainder,
1094                       defun_type):
1095    """Test with `RaggedTensor`s and `SparseTensor`s."""
1096    self.skipTest("b/213596871, b/214574707")
1097
1098    if not tf2.enabled():
1099      self.skipTest("Only V2 is supported.")
1100
1101    defun = {
1102        "lambda": lambda f: f,
1103        "tf_function": def_function.function
1104    }[defun_type]
1105    distribution.extended.experimental_enable_get_next_as_optional = True
1106    global_batch_size = 8
1107
1108    def dataset_fn(ctx=None):
1109      ctx = ctx or distribute_lib.InputContext()
1110      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
1111      # Use 20 which isn't divisible by 8 to test partial batch behavior.
1112      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
1113      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
1114          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
1115      dataset = dataset_ops.DatasetV2.from_tensor_slices({
1116          "dense": ragged_tensor.to_tensor(),
1117          "ragged": ragged_tensor,
1118          "sparse": ragged_tensor.to_sparse(),
1119      })
1120      dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
1121      return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
1122
1123    dataset_or_input_fn = self._create_dataset_or_input_fn(
1124        input_type, dataset_fn)
1125    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
1126                                 distribution.extended._input_workers,
1127                                 distribution.num_replicas_in_sync,
1128                                 distribution)
1129    # Assert that the tensors are rebatched and sparsity is preserved.
1130    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
1131    self.assertAllEqual(
1132        distribute_utils.select_replica(0, per_replica_batch["dense"]),
1133        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
1134    self.assertAllEqual(
1135        distribute_utils.select_replica(1, per_replica_batch["dense"]),
1136        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
1137    # Transitively check the ragged and sparse tensors by densification.
1138    for i in range(2):
1139      self.assertLen(
1140          distribute_utils.select_replica(i,
1141                                          per_replica_batch["ragged"]).values,
1142          6)
1143      self.assertAllEqual(
1144          distribute_utils.select_replica(
1145              i, per_replica_batch["ragged"]).to_tensor(),
1146          distribute_utils.select_replica(i, per_replica_batch["dense"]))
1147      self.assertLen(
1148          distribute_utils.select_replica(i,
1149                                          per_replica_batch["sparse"]).indices,
1150          6)
1151      self.assertAllEqual(
1152          sparse_ops.sparse_tensor_to_dense(
1153              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
1154          distribute_utils.select_replica(i, per_replica_batch["dense"]))
1155    # Iterate through all the batches and sum them up.
1156    def sum_batch(per_replica_features):
1157      """Sums the `PerReplica` values in the `per_replica_features` map."""
1158
1159      def map_fn(per_replica_values):
1160        per_replica_sums = distribution.run(
1161            (lambda x: math_ops.reduce_sum(x.values)) if all(
1162                map(sparse_tensor.is_sparse, per_replica_values.values)) else
1163            math_ops.reduce_sum, (per_replica_values,))
1164        return distribution.reduce(
1165            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)
1166
1167      return nest.map_structure(map_fn, per_replica_features)
1168
1169    def _reduce(state, batch):
1170      sums = sum_batch(batch)
1171      return {name: value + sums[name] for name, value in state.items()}
1172
1173    def sum_for_loop(dataset):
1174      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
1175      for batch in dataset:
1176        sums = _reduce(sums, batch)
1177      return sums
1178
1179    def sum_while_loop(iterator, reduce_fn):
1180      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
1181      while True:
1182        try:
1183          sums = reduce_fn(sums, iterator)
1184        except (StopIteration, errors.OutOfRangeError):
1185          return sums
1186
1187    while_sums = sum_while_loop(
1188        iter(dataset),
1189        defun(lambda state, iterator: _reduce(state, next(iterator))))
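    # The full dataset sums to 310 (sum of i * (i % 4) for i in range(20));
    # with drop_remainder the final partial global batch (elements 16..19,
    # summing to 110) is dropped, leaving 200.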
1190    self.assertAllEqual(
1191        nest.flatten(while_sums),
1192        # When there's no partial batch, the sum is smaller.
1193        [200. if drop_remainder else 310.] * 3)
1194    for_sums = defun(sum_for_loop)(dataset)
1195    # For loops always call get_next_as_optional inside tf.functions, so we
1196    # expect 310 here when using an input function (as there are 5 batches of
1197    # size 4 round-robined over 2 replicas).
1198    expected_for_sum = 200.
1199    if (not drop_remainder or
1200        (defun_type == "tf_function" and input_type == "input_fn")):
1201      expected_for_sum = 310.
1202    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
1203
1204  @combinations.generate(
1205      combinations.combine(
1206          mode=["eager"],
1207          distribution=[
1208              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1209              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
1210              strategy_combinations.one_device_strategy,
1211              strategy_combinations.mirrored_strategy_with_one_cpu
1212          ],
1213          input_type=["dataset", "input_fn"],
1214          drop_remainder=[False, True],
1215          tensor_type=["sparse", "ragged"],
1216          enable_get_next_as_optional=[True, False]))
1217  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
1218                                        drop_remainder, tensor_type,
1219                                        enable_get_next_as_optional):
1220    """Test with `RaggedTensor`s and `SparseTensor`s."""
1221    if not tf2.enabled():
1222      self.skipTest("Only V2 is supported.")
1223
1224    distribution.extended.experimental_enable_get_next_as_optional = (
1225        enable_get_next_as_optional)
1226    global_batch_size = 8
1227
1228    def dataset_fn(ctx=None):
1229      ctx = ctx or distribute_lib.InputContext()
1230      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
1231      # Use 20 which isn't divisible by 8 to test partial batch behavior.
1232      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
1233      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
1234          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
1235      dataset = dataset_ops.DatasetV2.from_tensor_slices({
1236          tensor_type: (ragged_tensor if tensor_type == "ragged" else
1237                        ragged_tensor.to_sparse()),
1238      })
1239      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
1240      return dataset.batch(batch_size, drop_remainder=drop_remainder)
1241
1242    if input_type == "dataset":
1243      ds = distribution.experimental_distribute_dataset(
1244          dataset_fn(distribute_lib.InputContext()))
1245    else:
1246      ds = distribution.distribute_datasets_from_function(dataset_fn)
1247    iterator = iter(ds)
1248
1249    self.assertEqual(iterator._enable_get_next_as_optional,
1250                     (not drop_remainder) and enable_get_next_as_optional)
1251
1252  @combinations.generate(
1253      combinations.combine(
1254          tf_api_version=2,
1255          mode=["eager"],
1256          distribution=[
1257              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1258              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
1259              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1260              strategy_combinations.multi_worker_mirrored_2x1_gpu,
1261              strategy_combinations.multi_worker_mirrored_2x2_gpu,
1262          ],
1263          input_type=["dataset", "input_fn"],
1264          drop_remainder=[False, True],
1265      ))
1266  def testRaggedSparseGetNextAsOptionalInLoop(self, distribution, input_type,
1267                                              drop_remainder):
1268    """Test with `RaggedTensor`s and `SparseTensor`s."""
1269    global_batch_size = 8
1270
1271    def dataset_fn(ctx=None):
1272      ctx = ctx or distribute_lib.InputContext()
1273      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
1274      # Use 20, which isn't divisible by 8, to test partial batch behavior.
1275      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
1276      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
1277          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
1278      dataset = dataset_ops.DatasetV2.from_tensor_slices({
1279          "dense": ragged_tensor.to_tensor(),
1280          "ragged": ragged_tensor,
1281          "sparse": ragged_tensor.to_sparse(),
1282      })
1283      dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
1284      return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
1285
1286    if input_type == "dataset":
1287      ds = distribution.experimental_distribute_dataset(
1288          dataset_fn(distribute_lib.InputContext()))
1289    else:
1290      ds = distribution.distribute_datasets_from_function(dataset_fn)
1291
1292    # Iterate through all the batches and sum them up.
1293    def sum_batch(per_replica_features):
1294      """Sums the `PerReplica` values in the `per_replica_features` map."""
1295
1296      def map_fn(per_replica_values):
1297
1298        def _sum(value):
1299          if sparse_tensor.is_sparse(value):
1300            return math_ops.reduce_sum(value.values)
1301          else:
1302            return math_ops.reduce_sum(value)
1303
1304        per_replica_sums = distribution.run(_sum, args=(per_replica_values,))
1305        return distribution.reduce(
1306            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)
1307
1308      return nest.map_structure(map_fn, per_replica_features)
1309
1310    def _reduce(state, batch):
1311      sums = sum_batch(batch)
1312      return {name: value + sums[name] for name, value in state.items()}
1313
1314    def sum_while_loop(ds):
1315      iterator = iter(ds)
1316      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
1317      try_next = constant_op.constant(True)
1318
1319      while try_next:
1320        opt_iterate = iterator.get_next_as_optional()
1321        if opt_iterate.has_value():
1322          sums = _reduce(sums, opt_iterate.get_value())
1323        else:
1324          try_next = False
1325      return sums
1326
1327    sums = def_function.function(sum_while_loop)(ds)
1328    # With an input function, each replica batches with its own per-replica
1329    # batch size (e.g. 5 batches of size 4 split across 2 workers), so no
1330    # rows are dropped even when drop_remainder=True and we expect 310.
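    # The feature values sum to 310 over the full dataset (sum of i * (i % 4)
    # for i in [0, 20)); with input_type == "dataset" and drop_remainder=True
    # the final partial global batch (rows 16-19) is dropped, removing
    # 17 + 2*18 + 3*19 = 110 and leaving 200.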
1331    expected_for_sum = 200.
1332    if not drop_remainder or input_type == "input_fn":
1333      expected_for_sum = 310.
1334    self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)
1335
1336  @combinations.generate(
1337      combinations.combine(
1338          mode=["eager"],
1339          input_type=["dataset"],
1340          api_type=["wrap_into_iterator", "wrap_into_dataset"],
1341          iteration_type=["get_next", "for_loop"],
1342          distribution=[
1343              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1344              strategy_combinations.multi_worker_mirrored_2x1_gpu,
1345          ]))
1346  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
1347                           distribution):
1348    # Test case: 2 workers, 1 replica each.
1349    # This test simulates the sharded behavior when we have two files each with
1350    # 12 elements and a global batch size of 8. When we consider the dataset in
1351    # aggregate (non-distributed), there are 24 elements divided into 3 batches
1352    # of size 8. Hence, the correct distributed behavior is for each replica to
1353    # see sub-batches of size 4, over three steps.
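    # Per worker, range(12).batch(8) yields batches [0..7] and [8..11]; these
    # get rebatched to the per-replica batch size of 4, yielding [0..3],
    # [4..7], [8..11], which is what `expected_values` below encodes.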
1354    def dataset_fn(ctx):
1355      del ctx
1356      dataset = dataset_ops.Dataset.range(12).batch(8)
1357
1358      # Set the sharding behavior to OFF for simplicity of test setup; namely,
1359      # `dataset` defines the per-worker dataset and will not be further
1360      # sharded. Each worker will see a dataset that is
1361      # tf.data.Dataset.range(12).batch(8).rebatch(...).
1362      options = options_lib.Options()
1363      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
1364      dataset = dataset.with_options(options)
1365      return dataset
1366
1367    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
1368
1369    # Actual devices don't matter in this test as long as there is 1 local
1370    # replica.
1371    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1372
1373    # Each test runs individually on each worker, so we compare the
1374    # values on each worker. Each worker should rebatch its dataset into
1375    # smaller batches of size 4.
1376    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
1377    self._test_input_iteration(
1378        input_type,
1379        api_type,
1380        iteration_type,
1381        dataset,
1382        worker_device_pairs,
1383        expected_values,
1384        distribution,
1385        num_replicas_in_sync=distribution.num_replicas_in_sync,
1386        input_context=distribution.extended._make_input_context())
1387
1388  @combinations.generate(
1389      combinations.combine(
1390          mode=["eager"],
1391          input_type=["dataset"],
1392          api_type=["wrap_into_iterator", "wrap_into_dataset"],
1393          iteration_type=["get_next", "for_loop"],
1394          distribution=[
1395              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1396              strategy_combinations.multi_worker_mirrored_2x1_gpu,
1397          ]))
1398  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
1399                                            iteration_type, distribution):
1400    # Test case: 2 workers, 1 replica each.
1401    # This test simulates the sharded behavior when we have two files each with
1402    # 12 elements and a global batch size of 8. When we consider the dataset in
1403    # aggregate (non-distributed), there are 24 elements divided into 3 batches
1404    # of size 8. Hence, the correct distributed behavior is for each replica to
1405    # see sub-batches of size 4, over three steps. However, when we create a
1406    # DistributedDataset and cannot statically infer the intended global batch
1407    # size (e.g. if the user does not use a batching dataset), each worker will
1408    # rebatch based on the dynamic batch size of the data encountered, even when
1409    # it encounters partial batches. The last per-worker partial batch (size 4)
1410    # ends up being split into two replicas, resulting in 4 steps in total, of
1411    # (global) batch sizes 8, 8, 4, 4.
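    # Per worker, the dataset below yields batches [0..7] and [8..11];
    # LegacyRebatch splits each of them across the 2 global replicas, yielding
    # [0..3], [4..7], [8..9], [10..11] as captured in `expected_values` below.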
1412    def dataset_fn(ctx):
1413      del ctx
1414      # The following dataset is equivalent to
1415      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
1416      # This causes DistributedDataset to use LegacyRebatch instead.
1417      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
1418      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
1419      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))
1420
1421      def map_fn(offset, batch_size):
1422        return math_ops.range(offset, offset + batch_size)
1423
1424      dataset = dataset.map(map_fn)
1425
1426      # Set the sharding behavior to OFF for simplicity of test setup; namely,
1427      # `dataset` defines the per-worker dataset and will not be further
1428      # sharded. Each worker will see a dataset that is equivalent to
1429      # tf.data.Dataset.range(12).batch(8).rebatch(...).
1430      options = options_lib.Options()
1431      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
1432      dataset = dataset.with_options(options)
1433      return dataset
1434
1435    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
1436
1437    # Actual devices don't matter in this test as long as the number of global
1438    # replicas is 2.
1439    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1440
1441    # Each test runs individually on each worker, so we compare the
1442    # values on each worker. Each worker should split its full batches into
1443    # batches of size 4 and its final partial batch into batches of size 2.
1444    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
1445    self._test_input_iteration(
1446        input_type,
1447        api_type,
1448        iteration_type,
1449        dataset,
1450        worker_device_pairs,
1451        expected_values,
1452        distribution,
1453        num_replicas_in_sync=distribution.num_replicas_in_sync,
1454        input_context=distribution.extended._make_input_context())
1455
1456  @combinations.generate(
1457      combinations.combine(
1458          mode=["eager"],
1459          input_type=["dataset"],
1460          api_type=["wrap_into_iterator", "wrap_into_dataset"],
1461          iteration_type=["get_next", "for_loop"],
1462          distribution=[
1463              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1464              strategy_combinations.multi_worker_mirrored_2x1_gpu,
1465          ],
1466          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
1467  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
1468                               distribution, auto_shard_policy):
1469    # Test case: 2 workers, 1 replica each.
1470    # This test simulates sharding by data when the global batch size is not
1471    # divisible by the number of replicas. It checks that the elements are as
1472    # expected and that the batch sizes across all workers add up to 3. It
1473    # only passes if the autoshard rewrite rewrites RebatchDatasetV2 to the
1474    # legacy RebatchDataset when sharding by data.
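    # Concretely, range(8).batch(3) yields [0, 1, 2], [3, 4, 5], [6, 7]. Each
    # batch is split between the two workers: worker 0 sees [0, 1], [3, 4],
    # [6], and worker 1 sees [2], [5], [7] (see `expected_values` below).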
1475    def dataset_fn(ctx):
1476      del ctx
1477      dataset = dataset_ops.Dataset.range(8).batch(3)
1478
1479      # Unlike the tests above, use the parameterized `auto_shard_policy`
1480      # (AUTO or DATA) so that the dataset is sharded by data: each worker
1481      # reads the full tf.data.Dataset.range(8).batch(3) and receives its
1482      # share of each batch.
1483      options = options_lib.Options()
1484      options.experimental_distribute.auto_shard_policy = auto_shard_policy
1485      dataset = dataset.with_options(options)
1486      return dataset
1487
1488    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
1489
1490    # Actual devices don't matter in this test as long as there is 1 local
1491    # replica.
1492    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1493
1494    # Each test runs individually on each worker, so we compare the
1495    # values on each worker. We expect each worker to see different shards of
1496    # data.
1497    cr = distribution.cluster_resolver
1498    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
1499                                                cr.task_id)
1500
1501    if worker_id == 0:
1502      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
1503    elif worker_id == 1:
1504      expected_values = [[[2]], [[5]], [[7]]]
1505
1506    self._test_input_iteration(
1507        input_type,
1508        api_type,
1509        iteration_type,
1510        dataset,
1511        worker_device_pairs,
1512        expected_values,
1513        distribution,
1514        num_replicas_in_sync=distribution.num_replicas_in_sync,
1515        input_context=distribution.extended._make_input_context())
1516
1517
1518@framework_test_util.with_eager_op_as_function
1519class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
1520                                       parameterized.TestCase):
1521  """Tests for PER_WORKER and PER_REPLICA's InputOptions variants."""
1522
1523  def setUp(self):
1524    context._reset_context()
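    # The mirrored_strategy_with_cpu_1_and_2 combination below requires
    # devices CPU:1 and CPU:2, hence at least 3 virtual CPUs.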
1525    strategy_combinations.set_virtual_cpus_to_at_least(3)
1526    super(DistributedIteratorPerDeviceTest, self).setUp()
1527
1528  @combinations.generate(
1529      combinations.combine(
1530          input_options=[
1531              distribute_lib.InputOptions(
1532                  experimental_place_dataset_on_device=False,
1533                  experimental_fetch_to_device=True,
1534                  experimental_replication_mode=distribute_lib
1535                  .InputReplicationMode.PER_WORKER),
1536              distribute_lib.InputOptions(
1537                  experimental_place_dataset_on_device=False,
1538                  experimental_fetch_to_device=True,
1539                  experimental_replication_mode=distribute_lib
1540                  .InputReplicationMode.PER_REPLICA),
1541          ],
1542          mode=["eager"],
1543          distribution=[
1544              strategy_combinations.mirrored_strategy_with_two_gpus,
1545              strategy_combinations
1546              .mirrored_strategy_with_two_gpus_no_merge_call,
1547              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1548              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1549          ]))
1550  def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
1551                                                        input_options):
1552
1553    def dataset_fn(input_context):  # pylint: disable=unused-argument
1554      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])
1555
1556    ds = distribution.experimental_distribute_datasets_from_function(
1557        dataset_fn, input_options)
1558
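    # With experimental_fetch_to_device=True, per-replica values are expected
    # to be prefetched onto the replica devices, so both `device` and
    # `backing_device` should match the corresponding worker device.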
1559    for x in ds:
1560      assert x.values[0].device == distribution.extended.worker_devices[0]
1561      assert x.values[0].backing_device == distribution.extended.worker_devices[
1562          0]
1563      assert x.values[1].device == distribution.extended.worker_devices[1]
1564      assert x.values[1].backing_device == distribution.extended.worker_devices[
1565          1]
1566
1567  @combinations.generate(
1568      combinations.combine(
1569          distribution=[
1570              strategy_combinations.mirrored_strategy_with_two_gpus,
1571              strategy_combinations
1572              .mirrored_strategy_with_two_gpus_no_merge_call,
1573              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1574              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1575          ],
1576          input_options=[
1577              distribute_lib.InputOptions(
1578                  experimental_place_dataset_on_device=False,
1579                  experimental_fetch_to_device=False,
1580                  experimental_replication_mode=distribute_lib
1581                  .InputReplicationMode.PER_WORKER)
1582          ],
1583          mode=["eager"],
1584      ))
1585  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
1586      self, distribution, input_options):
1587
1588    def dataset_fn(input_context):
1589      return dataset_ops.Dataset.from_tensor_slices(
1590          np.full(4, input_context.input_pipeline_id))
1591
1592    ds = distribution.experimental_distribute_datasets_from_function(
1593        dataset_fn, input_options)
1594
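    # With experimental_fetch_to_device=False, values stay in host memory, so
    # both `device` and `backing_device` should be the local CPU even when the
    # compute devices are GPUs.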
1595    for x in ds:
1596      x = distribution.run(lambda inputs: inputs, args=(x,))
1597      assert x.values[
1598          0].device == "/job:localhost/replica:0/task:0/device:CPU:0"
1599      assert x.values[
1600          0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
1601      assert x.values[
1602          1].device == "/job:localhost/replica:0/task:0/device:CPU:0"
1603      assert x.values[
1604          1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
1605
1606  @combinations.generate(
1607      combinations.combine(
1608          input_options=[
1609              distribute_lib.InputOptions(
1610                  experimental_place_dataset_on_device=True,
1611                  experimental_fetch_to_device=False,
1612                  experimental_replication_mode=distribute_lib
1613                  .InputReplicationMode.PER_WORKER),
1614              distribute_lib.InputOptions(
1615                  experimental_place_dataset_on_device=True,
1616                  experimental_fetch_to_device=True,
1617                  experimental_replication_mode=distribute_lib
1618                  .InputReplicationMode.PER_REPLICA)
1619          ],
1620          mode=["eager"],
1621          distribution=[
1622              strategy_combinations.mirrored_strategy_with_two_gpus,
1623              strategy_combinations
1624              .mirrored_strategy_with_two_gpus_no_merge_call,
1625              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1626              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1627          ]))
1628  def testDevicePlacementForInvalidCombinations(self, distribution,
1629                                                input_options):
1630
1631    def dataset_fn(input_context):
1632      return dataset_ops.Dataset.from_tensor_slices(
1633          np.full(4, input_context.input_pipeline_id))
1634
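    # Placing the dataset directly on devices requires PER_REPLICA replication
    # with experimental_fetch_to_device=False, so both combinations above
    # should be rejected with a ValueError.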
1635    with self.assertRaises(ValueError):
1636      distribution.experimental_distribute_datasets_from_function(
1637          dataset_fn, input_options)
1638
1639  @combinations.generate(
1640      combinations.combine(
1641          input_options=[
1642              distribute_lib.InputOptions(
1643                  experimental_place_dataset_on_device=False,
1644                  experimental_fetch_to_device=False,
1645                  experimental_per_replica_buffer_size=2),
1646              distribute_lib.InputOptions(
1647                  experimental_place_dataset_on_device=False,
1648                  experimental_fetch_to_device=True,
1649                  experimental_per_replica_buffer_size=2),
1650          ],
1651          mode=["eager"],
1652          distribution=[
1653              strategy_combinations.mirrored_strategy_with_two_gpus,
1654              strategy_combinations
1655              .mirrored_strategy_with_two_gpus_no_merge_call,
1656              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1657              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1658          ]))
1659  def testPrefetchBufferSizeInputOptions(self, distribution, input_options):
1660
1661    def dataset_fn(input_context):
1662      return dataset_ops.Dataset.from_tensor_slices(
1663          np.arange(1, 11).reshape(
1664              (2, 5)) * (input_context.input_pipeline_id + 1))
1665
1666    ds = distribution.experimental_distribute_datasets_from_function(
1667        dataset_fn, input_options)
1668
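    # The buffer size only affects prefetching: with a single input pipeline
    # the dataset yields [1..5] and [6..10], and successive elements should
    # still go to successive replicas.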
1669    # validating the values
1670    x = next(iter(ds))
1671    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
1672    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
1673
1674  @combinations.generate(
1675      combinations.combine(
1676          input_options=[
1677              distribute_lib.InputOptions(
1678                  experimental_place_dataset_on_device=False,
1679                  experimental_fetch_to_device=False,
1680                  experimental_replication_mode=distribute_lib
1681                  .InputReplicationMode.PER_WORKER),
1682              distribute_lib.InputOptions(
1683                  experimental_place_dataset_on_device=False,
1684                  experimental_fetch_to_device=True,
1685                  experimental_replication_mode=distribute_lib
1686                  .InputReplicationMode.PER_WORKER),
1687          ],
1688          mode=["eager"],
1689          distribution=[
1690              strategy_combinations.mirrored_strategy_with_two_gpus,
1691              strategy_combinations
1692              .mirrored_strategy_with_two_gpus_no_merge_call,
1693              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1694              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1695          ]))
1696  def testOutputValuesForPerWorkerInputOptions(self, distribution,
1697                                               input_options):
1698
1699    def dataset_fn(input_context):
1700      return dataset_ops.Dataset.from_tensor_slices(
1701          np.arange(1, 11).reshape(
1702              (2, 5)) * (input_context.input_pipeline_id + 1))
1703
1704    ds = distribution.experimental_distribute_datasets_from_function(
1705        dataset_fn, input_options)
1706
1707    # validating the values
1708    x = next(iter(ds))
1709    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
1710    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
1711
1712  @combinations.generate(
1713      combinations.combine(
1714          input_options=[
1715              distribute_lib.InputOptions(
1716                  experimental_place_dataset_on_device=True,
1717                  experimental_fetch_to_device=False,
1718                  experimental_replication_mode=distribute_lib
1719                  .InputReplicationMode.PER_REPLICA),
1720              distribute_lib.InputOptions(
1721                  experimental_place_dataset_on_device=False,
1722                  experimental_fetch_to_device=False,
1723                  experimental_replication_mode=distribute_lib
1724                  .InputReplicationMode.PER_REPLICA),
1725              distribute_lib.InputOptions(
1726                  experimental_place_dataset_on_device=False,
1727                  experimental_fetch_to_device=True,
1728                  experimental_replication_mode=distribute_lib
1729                  .InputReplicationMode.PER_REPLICA),
1730          ],
1731          mode=["eager"],
1732          distribution=[
1733              strategy_combinations.mirrored_strategy_with_two_gpus,
1734              strategy_combinations
1735              .mirrored_strategy_with_two_gpus_no_merge_call,
1736              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1737              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1738          ]))
1739  def testOutputValuesForPerReplicaInputOptions(self, distribution,
1740                                                input_options):
1741
1742    def dataset_fn(input_context):
1743      return dataset_ops.Dataset.from_tensor_slices(
1744          np.arange(1, 10) * (input_context.input_pipeline_id + 1))
1745
1746    ds = distribution.experimental_distribute_datasets_from_function(
1747        dataset_fn, input_options)
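    # In PER_REPLICA mode, dataset_fn runs once per replica with a distinct
    # input_pipeline_id (0 and 1 here), so replica 0 should see arange(1, 10)
    # while replica 1 sees the same values multiplied by 2.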
1748    expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
1749    for i, x in enumerate(ds):
1750      # validating the values
1751      assert x.values[0].numpy() == expected[i]
1752      assert x.values[1].numpy() == expected[i] * 2
1753      loop_num = i
1754    assert loop_num == len(expected) - 1
1755
1756
1757class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
1758                                           parameterized.TestCase):
1759  """Tests for distributed iterators which read from tf.data service."""
1760
1761  def setUp(self):
1762    super(DistributedIteratorTfDataServiceTest, self).setUp()
1763    self.num_workers = 3
1764    if combinations.in_main_process():
1765      self.dispatcher = server_lib.DispatchServer()
1766      self.workers = []
1767      for _ in range(self.num_workers):
1768        self.workers.append(
1769            server_lib.WorkerServer(
1770                server_lib.WorkerConfig(
1771                    dispatcher_address=self.dispatcher.target.split("://")[1],
1772                    heartbeat_interval_ms=100,
1773                    dispatcher_timeout_ms=1000)))
1774      combinations.env().tf_data_service_dispatcher = self.dispatcher.target
1775
1776  @combinations.generate(
1777      combinations.combine(
1778          mode=["eager"],
1779          distribution=[
1780              strategy_combinations.one_device_strategy,
1781              strategy_combinations.mirrored_strategy_with_one_cpu,
1782              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1783              strategy_combinations.tpu_strategy,
1784              strategy_combinations.central_storage_strategy_with_two_gpus,
1785              strategy_combinations.multi_worker_mirrored_2x2_gpu,
1786              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
1787              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1788          ]))
1789  def testTfDataService(self, distribution):
1790    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1791    input_workers = input_lib.InputWorkers(worker_device_pairs)
1792
1793    dataset = dataset_ops.Dataset.range(1, 50)
1794    dataset = dataset.apply(
1795        data_service_ops._distribute(
1796            processing_mode=data_service_ops.ShardingPolicy.OFF,
1797            service=combinations.env().tf_data_service_dispatcher,
1798            job_name="foo"))
1799
1800    dist_dataset = input_util.get_distributed_dataset(dataset, input_workers,
1801                                                      distribution)
1802    iterator = iter(dist_dataset)
1803    results = []
1804    for element in iterator:
1805      local_results = distribution.experimental_local_results(element)
1806      for result in local_results:
1807        # input_lib.distributed_dataset may add extra '0' elements to pad
1808        # per-replica results.
1809        if result.numpy() != 0:
1810          results.append(result.numpy())
1811    self.assertNotEmpty(results)
1812    gathered = distribution.gather(constant_op.constant(results), axis=0)
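    # With ShardingPolicy.OFF each tf.data service worker produces the full
    # dataset, so across the shared job we expect num_workers copies of
    # range(1, 50).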
1813    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)
1814
1815    histogram_proto = (
1816        input_lib._distributed_dataset_initialization_time_milliseconds
1817        .get_cell(distribution.__class__.__name__, "1").value())
1818    self.assertGreater(histogram_proto.num, 0.0)
1819
1820  @combinations.generate(
1821      combinations.combine(
1822          mode=["eager"],
1823          distribution=[
1824              strategy_combinations.one_device_strategy,
1825              strategy_combinations.mirrored_strategy_with_one_cpu,
1826              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1827              strategy_combinations.tpu_strategy,
1828              strategy_combinations.central_storage_strategy_with_two_gpus,
1829              strategy_combinations.multi_worker_mirrored_2x2_gpu,
1830              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
1831              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1832          ]))
1833  def testDistributeDatasetFromFunction(self, distribution):
1834    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1835    input_workers = input_lib.InputWorkers(worker_device_pairs)
1836    input_contexts = []
1837    num_workers = input_workers.num_workers
1838    for i in range(num_workers):
1839      input_contexts.append(distribute_lib.InputContext(
1840          num_input_pipelines=num_workers,
1841          input_pipeline_id=i,
1842          num_replicas_in_sync=num_workers))
1843
1844    dataset = dataset_ops.Dataset.range(1, 50)
1845    dataset_id = data_service_ops.register_dataset(
1846        service=combinations.env().tf_data_service_dispatcher,
1847        dataset=dataset)
1848
1849    def dataset_fn(input_context):
1850      del input_context
1851      return data_service_ops.from_dataset_id(
1852          processing_mode=data_service_ops.ShardingPolicy.OFF,
1853          service=combinations.env().tf_data_service_dispatcher,
1854          dataset_id=dataset_id,
1855          element_spec=dataset.element_spec,
1856          job_name="shared_job")
1857
1858    dist_dataset = input_util.get_distributed_datasets_from_function(
1859        dataset_fn, input_workers, input_contexts, distribution)
1860
1861    iterator = iter(dist_dataset)
1862    results = []
1863    for element in iterator:
1864      local_results = distribution.experimental_local_results(element)
1865      for result in local_results:
1866        # input_lib.distributed_dataset may add extra '0' elements to pad
1867        # per-replica results.
1868        if result.numpy() != 0:
1869          results.append(result.numpy())
1870    self.assertNotEmpty(results)
1871    gathered = distribution.gather(constant_op.constant(results), axis=0)
1872    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)
1873    histogram_proto = (
1874        input_lib
1875        ._distributed_dataset_from_function_initialization_time_milliseconds
1876        .get_cell(distribution.__class__.__name__, "1").value())
1877    self.assertGreater(histogram_proto.num, 0.0)
1878
1879  @combinations.generate(
1880      combinations.combine(
1881          mode=["eager"],
1882          distribution=[
1883              strategy_combinations.one_device_strategy,
1884              strategy_combinations.mirrored_strategy_with_one_cpu,
1885              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1886              strategy_combinations.mirrored_strategy_with_two_gpus,
1887              strategy_combinations.tpu_strategy,
1888              strategy_combinations.central_storage_strategy_with_two_gpus,
1889              strategy_combinations.multi_worker_mirrored_2x2_gpu,
1890              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
1891              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1892          ]))
1893  def testDistributeDatasetFromFunctionNested(self, distribution):
1894    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1895    input_workers = input_lib.InputWorkers(worker_device_pairs)
1896    input_contexts = []
1897    num_workers = input_workers.num_workers
1898    for i in range(num_workers):
1899      input_contexts.append(
1900          distribute_lib.InputContext(
1901              num_input_pipelines=num_workers,
1902              input_pipeline_id=i,
1903              num_replicas_in_sync=num_workers))
1904
1905    class InnerType(extension_type.ExtensionType):
1906      tensor: ops.Tensor
1907
1908    class OuterType(extension_type.ExtensionType):
1909      inner: InnerType
1910
1911    def dataset_fn(input_context):
1912      del input_context
1913
1914      def data_fn(batch_id) -> OuterType:
1915        del batch_id
1916
1917        return OuterType(
1918            inner=InnerType(tensor=constant_op.constant([[0., 1.], [2., 3.]])))
1919
1920      return dataset_ops.Dataset.range(1, 10).map(data_fn)
1921
1922    dist_dataset = input_util.get_distributed_datasets_from_function(
1923        dataset_fn, input_workers, input_contexts, distribution)
1924
1925    iterator = iter(dist_dataset)
1926    results = []
1927    for element in iterator:
1928      local_results = distribution.experimental_local_results(element)
1929      for result in local_results:
1930        results.append(result)
1931
1932    expect_component = OuterType(
1933        inner=InnerType(tensor=constant_op.constant([[0., 1.], [2., 3.]])))
1934    self.assertCountEqual(
1935        num_workers * [expect_component for _ in range(1, 10)], results)
1936
1937if __name__ == "__main__":
1938  test_util.main()
1939