# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for input pipeline modifications for distribution strategies."""

import os

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import input_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class AutoShardDatasetTest(test.TestCase):

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 4
    self._num_shards = 2
    self._shard_index = 0
    self._record_bytes = 10

  def _getNext(self, dataset):
    if context.executing_eagerly():
      iterator = iter(dataset)
      return iterator._next_internal  # pylint: disable=protected-access
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return lambda: get_next

  def _record(self, r, f):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _text_line(self, r, f):
    return compat.as_bytes("Text line %d of file %d" % (r, f))

  def _fixed_length_record(self, r, f):
    return compat.as_bytes(str((r * f) % 10) * self._record_bytes)

  def _createTFRecordFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        record = self._record(j, i)
        writer.write(record)
      writer.close()
    return filenames

  def _createTextFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(self._num_records):
        contents.append(self._text_line(j, i))
        # End every line with CRLF, except the last line of every file but
        # the first; this exercises files both with and without a trailing
        # newline.
        if j + 1 != self._num_records or i == 0:
          contents.append(b"\r\n")
      contents = b"".join(contents)

      with open(fn, "wb") as f:
        f.write(contents)
    return filenames

  def _createFixedLengthRecordFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_records):
          f.write(self._fixed_length_record(j, i))
    return filenames

  def _verifySimpleShardingOutput(self, dataset, record_fn):
    next_element_fn = self._getNext(dataset)
    with self.cached_session():
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(
              record_fn(r, f), self.evaluate(next_element_fn()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element_fn())
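
  # For reference: with the defaults above (10 files x 4 records, 2 shards,
  # shard index 0), file-level auto-sharding is expected to keep files
  # 0, 2, 4, 6, 8 and every record inside them, in file order -- which is
  # exactly what the range(shard_index, num_files, num_shards) loop in
  # _verifySimpleShardingOutput checks.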

  @test_util.run_in_graph_and_eager_modes
  def testTFRecordDataset(self):
    dataset = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testInterleave(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.interleave(
        readers.TFRecordDataset, cycle_length=4,
        block_length=self._num_records)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Since block_length == number of records in each file, the output still
    # contains records grouped in file order.
    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testListfiles(self):
    filenames = self._createTFRecordFiles()
    file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    actual, expected = [], []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        actual.append(self.evaluate(next_element_fn()))
        expected.append(self._record(r, f))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
    self.assertAllEqual(expected, actual)

  @test_util.run_in_graph_and_eager_modes
  def testComplexPipeline(self):
    # Set up a complex input pipeline.
    batch_size = 2
    num_epochs = 5
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.shuffle(buffer_size=self._num_files)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.shuffle(2 * self._num_files * self._num_records)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(lambda x: x)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=None)

    # Auto shard.
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Verify output.
    next_element_fn = self._getNext(dataset)
    actual = []
    # Each shard sees num_files * num_records * num_epochs / num_shards
    # records, consumed batch_size at a time:
    # (10 * 4 * 5) // (2 * 2) = 50 batches here.
    num_iterations = (self._num_files * self._num_records * num_epochs) // (
        self._num_shards * batch_size)
    for _ in range(num_iterations):
      actual.extend(self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

    expected = []
    for f in range(0, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs

    self.assertAllEqual(sorted(expected), sorted(actual))
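
  # The next two tests drive auto_shard_dataset with a dataset that has more
  # than one input pipeline. The assertions below encode the expectation that
  # both inputs are sharded with the same (num_shards, shard_index), so
  # zipped elements stay paired and concatenated segments cover the same
  # files.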

  @test_util.run_in_graph_and_eager_modes
  def testZip(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
    self._verifySimpleShardingOutput(dataset, record_fn)

  @test_util.run_in_graph_and_eager_modes
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._record(r, f), self.evaluate(next_element_fn()))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._text_line(r, f), self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReader(self):
    dataset = readers.TextLineDataset(self._createTextFiles())

    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReaderWithFlatMap(self):
    dataset = readers.TextLineDataset(self._createTextFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testFixedLengthReaderWithFlatMap(self):
    dataset = readers.FixedLengthRecordDataset(
        self._createFixedLengthRecordFiles(), self._record_bytes)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
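

# The remaining tests cover input_ops._clone_dataset, which is expected to
# rebuild a dataset's graph so that the clone yields the same elements as the
# original dataset.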


# A dataset that creates two variant tensors.
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):

  def __init__(self, input_dataset):
    self._input_dataset = input_dataset
    temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,
        buffer_size=1,
        **self._flat_structure)
    variant_tensor = gen_dataset_ops.model_dataset(
        temp_variant_tensor, **self._flat_structure)
    super(_TestDataset, self).__init__(input_dataset, variant_tensor)


class CloneDatasetTest(test.TestCase):

  def _assert_datasets_equal(self, ds1, ds2):
    # First assert that the structures are compatible.
    self.assertTrue(
        structure.are_compatible(ds1.element_spec, ds2.element_spec))

    # Now create iterators on both and assert they produce the same values.
    it1 = dataset_ops.make_initializable_iterator(ds1)
    it2 = dataset_ops.make_initializable_iterator(ds2)

    get_next1 = it1.get_next()
    get_next2 = it2.get_next()

    with self.cached_session():
      self.evaluate([it1.initializer, it2.initializer])
      val1, val2 = self.evaluate([get_next1, get_next2])
      self.assertEqual(val1, val2)

  @test_util.run_deprecated_v1
  def testOnlySource(self):
    ds = dataset_ops.Dataset.range(10)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testSimplePipeline(self):
    ds = dataset_ops.Dataset.range(10).map(math_ops.square)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testConcat(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = ds1.concatenate(ds2)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testZip(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testMultipleVariantTensors(self):
    ds = dataset_ops.Dataset.range(10)
    ds = _TestDataset(ds)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)


if __name__ == "__main__":
  test.main()