# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for input pipeline modifications for distribution strategies."""

import os

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import input_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class AutoShardDatasetTest(test.TestCase):
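  """Tests for sharding input pipelines with `input_ops.auto_shard_dataset`."""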

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 4
    self._num_shards = 2
    self._shard_index = 0
    self._record_bytes = 10

  def _getNext(self, dataset):
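    """Returns a callable that provides the next element of `dataset`."""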
    if context.executing_eagerly():
      iterator = iter(dataset)
      return iterator._next_internal  # pylint: disable=protected-access
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return lambda: get_next

  def _record(self, r, f):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _text_line(self, r, f):
    return compat.as_bytes("Text line %d of file %d" % (r, f))

  def _fixed_length_record(self, r, f):
    return compat.as_bytes(str((r * f) % 10) * self._record_bytes)

  def _createTFRecordFiles(self):
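    """Creates `_num_files` TFRecord files and returns their paths."""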
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        record = self._record(j, i)
        writer.write(record)
      writer.close()
    return filenames

  def _createTextFiles(self):
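    """Creates `_num_files` text files and returns their paths."""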
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(self._num_records):
        contents.append(self._text_line(j, i))
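        # A "\r\n" terminator follows every line, except the last line of every
        # file other than file 0 (which keeps its trailing newline).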
        if j + 1 != self._num_records or i == 0:
          contents.append(b"\r\n")
      contents = b"".join(contents)

      with open(fn, "wb") as f:
        f.write(contents)
    return filenames

  def _createFixedLengthRecordFiles(self):
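    """Creates `_num_files` fixed-length record files and returns their paths."""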
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_records):
          f.write(self._fixed_length_record(j, i))
    return filenames

  def _verifySimpleShardingOutput(self, dataset, record_fn):
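    """Asserts that `dataset` yields exactly this shard's records, in order."""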
    next_element_fn = self._getNext(dataset)
    with self.cached_session():
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(record_fn(r, f), self.evaluate(next_element_fn()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTFRecordDataset(self):
    dataset = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testInterleave(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.interleave(
        readers.TFRecordDataset, cycle_length=4, block_length=self._num_records)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Since block_length == num records in each file, the output will still
    # contain records in order of files.
    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testListfiles(self):
    filenames = self._createTFRecordFiles()
    file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    actual, expected = [], []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        actual.append(self.evaluate(next_element_fn()))
        expected.append(self._record(r, f))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
    self.assertAllEqual(expected, actual)

  @test_util.run_in_graph_and_eager_modes
  def testComplexPipeline(self):
    # Set up a complex input pipeline.
    batch_size = 2
    num_epochs = 5
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.shuffle(buffer_size=self._num_files)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.shuffle(2 * self._num_files * self._num_records)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(lambda x: x)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=None)

    # Auto shard.
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Verify output.
    next_element_fn = self._getNext(dataset)
    actual = []
    num_iterations = (self._num_files * self._num_records * num_epochs) // (
        self._num_shards * batch_size)
    for _ in range(num_iterations):
      actual.extend(self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

    expected = []
    for f in range(0, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs

    self.assertAllEqual(sorted(expected), sorted(actual))

  @test_util.run_in_graph_and_eager_modes
  def testZip(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
    self._verifySimpleShardingOutput(dataset, record_fn)

  @test_util.run_in_graph_and_eager_modes
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._record(r, f), self.evaluate(next_element_fn()))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._text_line(r, f), self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReader(self):
    dataset = readers.TextLineDataset(self._createTextFiles())

    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReaderWithFlatMap(self):
    dataset = readers.TextLineDataset(self._createTextFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testFixedLengthReaderWithFlatMap(self):
    dataset = readers.FixedLengthRecordDataset(
        self._createFixedLengthRecordFiles(), self._record_bytes)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._fixed_length_record)


# A dataset that creates two variant tensors.
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):

  def __init__(self, input_dataset):
    self._input_dataset = input_dataset
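    # Chain two dataset ops so that constructing this dataset produces two
    # variant tensors.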
    temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,
        buffer_size=1,
        **self._flat_structure)
    variant_tensor = gen_dataset_ops.model_dataset(
        temp_variant_tensor, **self._flat_structure)
    super(_TestDataset, self).__init__(input_dataset, variant_tensor)


class CloneDatasetTest(test.TestCase):
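  """Tests for `input_ops._clone_dataset`."""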

  def _assert_datasets_equal(self, ds1, ds2):
    # First let's assert the structure is the same.
    self.assertTrue(
        structure.are_compatible(ds1.element_spec, ds2.element_spec))

    # Now create iterators on both and assert they produce the same values.
    it1 = dataset_ops.make_initializable_iterator(ds1)
    it2 = dataset_ops.make_initializable_iterator(ds2)

    get_next1 = it1.get_next()
    get_next2 = it2.get_next()

    with self.cached_session():
      self.evaluate([it1.initializer, it2.initializer])
      val1, val2 = self.evaluate([get_next1, get_next2])
      self.assertEqual(val1, val2)

  @test_util.run_deprecated_v1
  def testOnlySource(self):
    ds = dataset_ops.Dataset.range(10)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testSimplePipeline(self):
    ds = dataset_ops.Dataset.range(10).map(math_ops.square)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testConcat(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = ds1.concatenate(ds2)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testZip(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testMultipleVariantTensors(self):
    ds = dataset_ops.Dataset.range(10)
    ds = _TestDataset(ds)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)


if __name__ == "__main__":
  test.main()