# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for distributed_table."""

import copy
import os

from absl.testing import parameterized

from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import coordinator_context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.saving import save as keras_save
from tensorflow.python.module import module
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import save as tf_save

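# The tests below are parameterized over the table-initializer source
# ("textfile" builds a `lookup_ops.TextFileInitializer`, "keyvaluetensor" a
# `lookup_ops.KeyValueTensorInitializer`) and, for the SavedModel round-trip
# tests, over the loading API ("tf_load" uses `tf.saved_model.load`,
# "keras_load" the Keras SavedModel loader).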
source_combination = combinations.combine(source=["textfile", "keyvaluetensor"])

source_and_load_combination = combinations.combine(
    source=["textfile", "keyvaluetensor"], load=["tf_load", "keras_load"])

class DistributedTableTest(test.TestCase, parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    super(DistributedTableTest, cls).setUpClass()
    cls.cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=2, num_ps=3, rpc_layer="grpc")
    cls.cluster_resolver = cls.cluster.cluster_resolver

  @classmethod
  def tearDownClass(cls):
    super(DistributedTableTest, cls).tearDownClass()
    cls.cluster.stop()

  def make_initializer(self, init_source, vals):
    if init_source == "textfile":
      file = os.path.join(self.get_temp_dir(), "text_file_initializer")
      with open(file, "w") as f:
        f.write("\n".join(str(v) for v in vals) + "\n")
      return lookup_ops.TextFileInitializer(
          filename=file,
          key_dtype=dtypes.int64,
          key_index=lookup_ops.TextFileIndex.LINE_NUMBER,
          value_dtype=dtypes.int64,
          value_index=lookup_ops.TextFileIndex.WHOLE_LINE)
    elif init_source == "keyvaluetensor":
      keys_tensor = constant_op.constant(
          list(range(len(vals))), dtype=dtypes.int64)
      vals_tensor = constant_op.constant(vals)
      return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)
    else:
      raise ValueError("Unrecognized init_source: " + init_source)

  def createStaticHashTable(self,
                            init_source=None,
                            vals=None,
                            default_value=None,
                            initializer=None):
    if not initializer:
      initializer = self.make_initializer(init_source, vals)
    return lookup_ops.StaticHashTable(
        initializer=initializer, default_value=default_value)

  def makeDatasetFromTensorWithoutUsingResource(self, input_context, tensor):
    """Returns a dataset made from `tensor`. To be called in a dataset_fn."""
    global_batch_size = 24
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = dataset_ops.DatasetV2.from_tensors(tensor).repeat().batch(
        batch_size, drop_remainder=True)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.
    return dataset
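
  # Note on the expected values in the tests below: `shard` is applied after
  # `batch`, so each fetched batch keeps all of its `batch_size` elements and
  # sharding only selects which batches an input pipeline sees. Assuming a
  # single replica per worker, the per-replica batch size equals the global
  # batch size of 24, so a batch of the scalar default value -2 sums to -48
  # and a batch of [0, 1, -2] rows sums to -24.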

  @combinations.generate(source_combination)
  def testCreateDistributedTableInScope(self, source):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    self.assertIsInstance(lookuptable, ps_values.DistributedTable)
    self.assertEqual(self.evaluate(lookuptable.size()), 3)

    # Lookup on the coordinator.
    output = lookuptable.lookup(
        constant_op.constant([0, 1, -1], dtype=dtypes.int64))
    self.assertAllEqual([0, 1, -2], output)
    self.assertEqual(lookuptable.size(), 3)

  @combinations.generate(source_combination)
  def testCopyDistributedTable(self, source):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    new_table = copy.copy(lookuptable)
    # No new coordinator instance or distributed tables are created.
    self.assertDictEqual(lookuptable.__dict__, new_table.__dict__)

  @combinations.generate(source_combination)
  def testCreateLookupInDatasetFnUnderScope(self, source):
    # TODO(wxinyi): Warn the user of the inefficiency of this workflow (i.e.
    # creating `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn` to
    # be distributed with `distribute_datasets_from_function` and
    # `create_per_worker_dataset`).
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():

      def dataset_fn(input_context):
        some_out_of_range_tensor = constant_op.constant(10, dtype=dtypes.int64)
        lookuptable = self.createStaticHashTable(
            init_source=source, vals=[0, 1, 2], default_value=-2)

        self.assertNotIsInstance(lookuptable, ps_values.DistributedTable)

        generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
        dataset = self.makeDatasetFromTensorWithoutUsingResource(
            input_context, generation_tensor)
        return dataset

      @def_function.function
      def per_worker_dataset_fn():
        return strategy.distribute_datasets_from_function(dataset_fn)

      per_worker_dataset = coordinator.create_per_worker_dataset(
          per_worker_dataset_fn)
      per_worker_iterator = iter(per_worker_dataset)

      @def_function.function
      def worker_fn(iterator):
        return math_ops.reduce_sum(next(iterator))

      result = []
      for _ in range(10):
        result.append(
            coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

      # Each batch contains 24 copies of the out-of-range default -2, so each
      # fetched sum is -48.
      for r in result:
        returned_input = r.fetch()
        self.assertAllClose(-48, returned_input)

  @combinations.generate(source_combination)
  def testAccessingResourceHandleInDatasetFnWithoutMap(self, source):

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    def dataset_fn(input_context):
      some_out_of_range_tensor = constant_op.constant(10, dtype=dtypes.int64)

      self.assertIsInstance(lookuptable, ps_values.DistributedTable)

      generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
      dataset = self.makeDatasetFromTensorWithoutUsingResource(
          input_context, generation_tensor)
      return dataset

    @def_function.function
    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(dataset_fn)

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @def_function.function
    def worker_fn(iterator):
      return math_ops.reduce_sum(next(iterator))

    result = []
    for _ in range(10):
      result.append(
          coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

    for r in result:
      returned_input = r.fetch()
      self.assertAllClose(-48, returned_input)

  @combinations.generate(source_combination)
  def testAccessingResourceHandleInDatasetFnWithMapFnDefinedInside(
      self, source):

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    def dataset_fn(input_context):
      generation_tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
      dataset = self.makeDatasetFromTensorWithoutUsingResource(
          input_context, generation_tensor)
      dataset = dataset.map(lookuptable.lookup)
      return dataset

    @def_function.function
    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(dataset_fn)

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @def_function.function
    def worker_fn(iterator):
      return math_ops.reduce_sum(next(iterator))

    result = []
    for _ in range(10):
      # batch_size == 24 and each row maps to [0, 1, -2], so each batch sums
      # to -24.
      result.append(
          coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

    for r in result:
      returned_input = r.fetch()
      self.assertAllClose(-24, returned_input)

  @combinations.generate(source_combination)
  def testAccessingResourceHandleInDatasetFnWithMapFnDefinedOutside(
      self, source):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    def map_fn(vals):
      return lookuptable.lookup(vals)

    def dataset_fn(input_context):
      generation_tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
      dataset = self.makeDatasetFromTensorWithoutUsingResource(
          input_context, generation_tensor)
      dataset = dataset.map(map_fn)
      return dataset

    @def_function.function
    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(dataset_fn)

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @def_function.function
    def worker_fn(iterator):
      return math_ops.reduce_sum(next(iterator))

    result = []
    for _ in range(10):
      # batch_size == 24 and each row maps to [0, 1, -2], so each batch sums
      # to -24.
      result.append(
          coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

    for r in result:
      returned_input = r.fetch()
      self.assertAllClose(-24, returned_input)

  class Model(module.Module):

    def __init__(self, init_source, filepath):
      vals = [0, 1, 2]
      if init_source == "textfile":

        with open(filepath, "w") as f:
          f.write("\n".join(str(v) for v in vals) + "\n")

        self.initializer = lookup_ops.TextFileInitializer(
            filepath, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
            dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
      else:
        keys_tensor = constant_op.constant(
            list(range(len(vals))), dtype=dtypes.int64)
        vals_tensor = constant_op.constant(vals)
        self.initializer = lookup_ops.KeyValueTensorInitializer(
            keys_tensor, vals_tensor)

      self.table = lookup_ops.StaticHashTable(
          self.initializer, default_value=-2)

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.int64)])
    def use_table(self, x):
      return self.table.lookup(x)

  def verifyWorkerLocalInstance(self, coordinator, model):
    # Assert that a worker-local table resource is captured on each worker.
    for worker in coordinator._cluster.workers:
      with coordinator_context.with_dispatch_context(worker):
        captures = model.use_table.get_concrete_function().captured_inputs
        resource_capture = [t for t in captures if t.dtype == dtypes.resource]
        self.assertNotEmpty(resource_capture)
        for capture in resource_capture:
          self.assertEqual(
              capture.device,
              device_util.canonicalize("/CPU:0", default=worker.device_name))
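
  # Under each worker's dispatch context, the deferred capture in `use_table`
  # is expected to resolve to that worker's local table instance, which is why
  # the captured resource tensor should be placed on the worker's CPU device.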

  @combinations.generate(source_combination)
  def testInModelAndCapture(self, source):

    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

    model = self.Model(source, file_path)
    func_captures = model.use_table.get_concrete_function(
    ).graph.external_captures
    self.assertLen(func_captures, 2)
    self.assertTrue(
        any(model.table.resource_handle is t for t in func_captures))
    deferred_captures = model.use_table.get_concrete_function(
    ).graph.deferred_external_captures
    self.assertEmpty(deferred_captures)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    coordinator = coordinator_lib.ClusterCoordinator(strategy)
    with strategy.scope():
      # "value" is not "textfile", so this takes the key-value tensor
      # initializer branch regardless of `source`.
      distributed_model = self.Model("value", file_path)
    func_captures = distributed_model.use_table.get_concrete_function(
    ).graph.external_captures
    # One fewer external capture, since the table handle becomes a closure
    # among the deferred external captures.
    self.assertLen(func_captures, 1)
    self.assertFalse(
        any(model.table.resource_handle is t for t in func_captures))
    deferred_captures = distributed_model.use_table.get_concrete_function(
    ).graph.deferred_external_captures
    self.assertNotEmpty(deferred_captures)

    self.verifyWorkerLocalInstance(coordinator, distributed_model)

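  # The SavedModel tests below check the capture structure after a save/load
  # round trip: loading outside a strategy scope yields two external captures
  # and no deferred captures, while loading under ParameterServerStrategyV2
  # yields a ps_values.DistributedTable with a single resource capture that is
  # resolved per worker through a deferred capture.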
  @combinations.generate(source_and_load_combination)
  def testDistributeTableSaveAndServe(self, load, source):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
    with strategy.scope():
      model = self.Model(source, file_path)

    model_dir = self.get_temp_dir()
    tf_save.save(model, model_dir)

    if load == "tf_load":
      load_fn = tf_load.load
    else:
      load_fn = keras_save.load_model

    loaded_without_strategy = load_fn(model_dir)
    loaded_func_captures_without_strategy = (
        loaded_without_strategy.use_table.get_concrete_function().graph
        .external_captures)
    loaded_func_deferred_captures_without_strategy = (
        loaded_without_strategy.use_table.get_concrete_function().graph
        .deferred_external_captures)
    self.assertLen(loaded_func_captures_without_strategy, 2)
    self.assertEmpty(loaded_func_deferred_captures_without_strategy)

    self.assertAllEqual(
        loaded_without_strategy.use_table(
            constant_op.constant([0, 1, 3], dtype=dtypes.int64)), [0, 1, -2])

  @combinations.generate(source_and_load_combination)
  def testDistributeTableSaveAndLoadUnderStrategy(self, load, source):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    coordinator = coordinator_lib.ClusterCoordinator(strategy)
    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
    with strategy.scope():
      model = self.Model(source, file_path)
    model_dir = self.get_temp_dir()
    tf_save.save(model, model_dir)

    if load == "tf_load":
      load_fn = tf_load.load
    else:
      load_fn = keras_save.load_model

    with strategy.scope():
      loaded = load_fn(model_dir)

    loaded_func_captures = (
        loaded.use_table.get_concrete_function().graph.external_captures)
    loaded_func_deferred_captures = (
        loaded.use_table.get_concrete_function().graph
        .deferred_external_captures)
    # Compared with loading without a strategy, there is one fewer external
    # capture, since the captured table handle has been swapped for a closure
    # among the deferred external captures.
    self.assertLen(loaded_func_captures, 1)
    self.assertNotEmpty(loaded_func_deferred_captures)

    self.assertIsInstance(loaded.table, ps_values.DistributedTable)

    self.assertLen([
        t for t in loaded.use_table.get_concrete_function().captured_inputs
        if t.dtype == dtypes.resource
    ], 1)

    self.verifyWorkerLocalInstance(coordinator, loaded)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()