xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/distribute_lib_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test DistributionStrategy, ReplicaContext, and supporting APIs."""
16
17from absl.testing import parameterized
18
19from tensorflow.python.autograph.core import converter_testing
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.distribute import combinations
22from tensorflow.python.distribute import distribute_lib
23from tensorflow.python.distribute import distribution_strategy_context as ds_context
24from tensorflow.python.distribute import input_lib
25from tensorflow.python.distribute import reduce_util
26from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
27from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import variable_scope
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import test
35from tensorflow.python.training import server_lib
36from tensorflow.python.util import nest
37
38
class _TestReplicaContext(distribute_lib.ReplicaContext):
  """Replica context whose `merge_call` echoes back its `test_arg` kwarg."""

  def merge_call(self, fn, *args, **kwargs):
    # `fn` and the positional args are intentionally ignored: tests only need
    # to observe that the call was routed to this context.
    del fn, args
    echoed = kwargs["test_arg"]
    return echoed
43
44
def _get_test_variable(name, synchronization, aggregation):
  """Builds the stand-in "variable" that `_TestExtended` creates.

  Args:
    name: The requested variable name.
    synchronization: The requested variable synchronization setting.
    aggregation: The requested variable aggregation setting.

  Returns:
    A dict with keys "name", "synchronization" and "aggregation" mapped to the
    corresponding arguments.
  """
  return dict(
      name=name,
      synchronization=synchronization,
      aggregation=aggregation,
  )
51
52
def _test_input_fn(input_context):
  """Input function producing an endless dataset of the scalar 1.0.

  Args:
    input_context: Ignored; accepted only to match the input_fn signature.

  Returns:
    A `DatasetV2` that repeats the value 1.0 forever.
  """
  del input_context  # Unused.
  single_one = dataset_ops.DatasetV2.from_tensors(1.)
  return single_one.repeat()
56
57
class _TestStrategy(distribute_lib.Strategy):
  """Strategy whose behaviour is entirely provided by `_TestExtended`."""

  def __init__(self):
    extended = _TestExtended(self)
    super().__init__(extended)
62
63
class _TestExtended(distribute_lib.StrategyExtendedV1):
  """Minimal single-replica `StrategyExtended` backing `_TestStrategy`.

  Variable creation returns `_get_test_variable` dicts, reductions return
  their input unchanged, and inputs come from a single "/device:CPU:0" worker.
  """

  def __init__(self, distribute):
    super().__init__(distribute)
    # One worker (host "") owning one CPU device.
    self._input_workers = input_lib.InputWorkers([("", ["/device:CPU:0"])])

  def _call_for_each_replica(self, fn, args, kwargs):
    # Run `fn` exactly once, as replica 0, inside a `_TestReplicaContext`.
    replica_ctx = _TestReplicaContext(
        self._container_strategy(), replica_id_in_sync_group=0)
    with replica_ctx:
      return fn(*args, **kwargs)

  def _create_variable(self, next_creator, **kwargs):
    # Deliberately does not call `next_creator`; returns a descriptive dict.
    del next_creator
    return _get_test_variable(
        name=kwargs["name"],
        synchronization=kwargs["synchronization"],
        aggregation=kwargs["aggregation"])

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    del replication_mode  # Unused by this test implementation.
    return input_lib_v1.InputFunctionIterator(
        input_fn,
        self._input_workers,
        [distribute_lib.InputContext()],
        self._container_strategy())

  def _distribute_datasets_from_function(self, dataset_fn, options):
    del options  # Unused.
    return dataset_fn(distribute_lib.InputContext())

  def _local_results(self, value):
    # Exactly one local replica, so its result is the only local result.
    return (value,)

  def _reduce_to(self, reduce_op, value, destinations, options):
    # With a single replica every reduction is the identity.
    del reduce_op, destinations, options
    return value

  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    # TODO(tomhennigan) This is missing many things (e.g. ctx.run_op).
    del initial_loop_values  # Unused.
    step_ctx = input_lib.MultiStepContext()
    for _ in range(iterations):
      fn(step_ctx, iterator.get_next())
    return step_ctx

  def _update(self, var, fn, args, kwargs, group):
    # Identical to `_update_non_slot`, except that `var` is prepended to the
    # positional arguments given to `fn`.
    return self._update_non_slot(var, fn, (var, *args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with  # Unused.
    outputs = fn(*args, **kwargs)
    return outputs if group else nest.map_structure(self._unwrap, outputs)

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group
121
122
def _assert_in_default_state(t):
  """Asserts that only the implicit default strategy is active.

  Args:
    t: The `test.TestCase` used to issue the assertions.
  """
  default_replica_ctx = ds_context._get_default_replica_context()
  t.assertIs(default_replica_ctx, ds_context.get_replica_context())
  # The default state is a (default) replica context, never cross-replica.
  t.assertIsNone(ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  default_strategy = ds_context._get_default_strategy()
  t.assertIs(default_strategy, ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
130
131
def _run_in_and_out_of_scope(unbound_test_method):
  """Decorates a test method so it runs under several strategy contexts.

  The decorated method, called as `(test_case, dist)`, is run three times:
  1. in the default (replica) context;
  2. inside `dist.scope()`;
  3. inside the scope of a different `_TestStrategy`. Here it must raise a
     `RuntimeError` matching "Mixing different .*Strategy objects".

  Args:
    unbound_test_method: A function taking `(test_case, dist)`.

  Returns:
    A test method taking only `test_case`. `functools.wraps` preserves the
    wrapped method's `__name__`, `__doc__` and `__wrapped__`, so failures and
    test listings report the real test name instead of "wrapper".
  """
  import functools  # pylint: disable=g-import-not-at-top

  @functools.wraps(unbound_test_method)
  def wrapper(test_case):
    dist = _TestStrategy()
    # Running in the default (replica) scope should be supported.
    _assert_in_default_state(test_case)
    unbound_test_method(test_case, dist)
    # As well as running in the strategy scope.
    with dist.scope():
      unbound_test_method(test_case, dist)
    _assert_in_default_state(test_case)
    # When run under a different strategy the test method should fail.
    another_strategy = _TestStrategy()
    msg = "Mixing different .*Strategy objects"
    with test_case.assertRaisesRegex(RuntimeError, msg):
      with another_strategy.scope():
        unbound_test_method(test_case, dist)

  return wrapper
149
150
class TestStrategyTest(test.TestCase):
  """Tests scope handling and the extended API using `_TestStrategy`."""

  def testCallForEachReplica(self):
    _assert_in_default_state(self)
    strategy = _TestStrategy()

    def replica_fn():
      ctx = ds_context.get_replica_context()
      self.assertIsNotNone(ctx)
      self.assertIsNone(ds_context.get_cross_replica_context())
      self.assertFalse(ds_context.in_cross_replica_context())
      self.assertTrue(ds_context.has_strategy())
      self.assertIs(strategy, ds_context.get_strategy())
      # `_TestReplicaContext.merge_call` echoes back `test_arg`.
      self.assertEqual("foo", ctx.merge_call(None, test_arg="foo"))
      # `_TestExtended._create_variable` returns a descriptive dict.
      self.assertDictEqual(
          _get_test_variable("bar",
                             variable_scope.VariableSynchronization.AUTO,
                             variable_scope.VariableAggregation.NONE),
          variable_scope.variable(1.0, name="bar"))

    # Must work both outside and inside the strategy's scope.
    strategy.extended.call_for_each_replica(replica_fn)
    with strategy.scope():
      strategy.extended.call_for_each_replica(replica_fn)
    _assert_in_default_state(self)

  def testScope(self):
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    with strategy.scope():
      # Inside the scope we are in the strategy's cross-replica context.
      self.assertIsNone(ds_context.get_replica_context())
      self.assertIs(strategy, ds_context.get_cross_replica_context())
      self.assertTrue(ds_context.in_cross_replica_context())
      self.assertTrue(ds_context.has_strategy())
      self.assertIs(strategy, ds_context.get_strategy())
      self.assertDictEqual(
          _get_test_variable("baz",
                             variable_scope.VariableSynchronization.AUTO,
                             variable_scope.VariableAggregation.NONE),
          variable_scope.variable(1.0, name="baz"))
    _assert_in_default_state(self)

  def testScopeDeviceNestingError(self):
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    # Giving the strategy a default device makes scope() open a device scope.
    strategy.extended._default_device = "/device:GPU:0"
    ctx_manager = strategy.scope()
    ctx_manager.__enter__()
    self.assertIs(strategy, ds_context.get_strategy())
    # Exiting while another device scope is still open must fail...
    with ops.device("/device:CPU:0"):
      with self.assertRaisesRegex(RuntimeError, "Device scope nesting error"):
        ctx_manager.__exit__(None, None, None)
    # ...and the properly nested exit afterwards restores the default state.
    ctx_manager.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testScopeVarCreatorNestingError(self):

    def passthrough_creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    strategy = _TestStrategy()
    ctx_manager = strategy.scope()
    ctx_manager.__enter__()
    self.assertIs(strategy, ds_context.get_strategy())
    # Exiting while an extra variable creator scope is open must fail.
    with variable_scope.variable_creator_scope(passthrough_creator):
      with self.assertRaisesRegex(RuntimeError,
                                  "Variable creator scope nesting error"):
        ctx_manager.__exit__(None, None, None)
    ctx_manager.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testScopeVarScopeNestingError(self):
    # A fresh graph keeps clean-up simple: the error fires part-way through
    # `__exit__()`, which leaves the graph's scope stacks in a weird state.
    with ops.Graph().as_default():
      _assert_in_default_state(self)
      strategy = _TestStrategy()
      ctx_manager = strategy.scope()
      ctx_manager.__enter__()
      self.assertIs(strategy, ds_context.get_strategy())
      with variable_scope.variable_scope("AA"):
        with self.assertRaisesRegex(RuntimeError,
                                    "Variable scope nesting error"):
          ctx_manager.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testSettingSynchronizationAndAggregation(self):
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    sync = variable_scope.VariableSynchronization.ON_WRITE
    agg = variable_scope.VariableAggregation.MEAN
    with strategy.scope():
      # Explicit synchronization/aggregation are forwarded to the creator.
      self.assertDictEqual(
          _get_test_variable("baz", sync, agg),
          variable_scope.variable(
              1.0, name="baz", synchronization=sync, aggregation=agg))
    _assert_in_default_state(self)

  def testSetStrategy(self):
    _assert_in_default_state(self)
    first = _TestStrategy()
    second = _TestStrategy()
    ds_context.experimental_set_strategy(first)
    # Setting a strategy behaves like being inside its scope.
    self.assertIsNone(ds_context.get_replica_context())
    self.assertIs(first, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(first, ds_context.get_strategy())
    self.assertDictEqual(
        _get_test_variable("baz",
                           variable_scope.VariableSynchronization.AUTO,
                           variable_scope.VariableAggregation.NONE),
        variable_scope.variable(1.0, name="baz"))
    # A strategy can be replaced directly by another one...
    ds_context.experimental_set_strategy(second)
    self.assertIs(second, ds_context.get_strategy())
    # ...and cleared with None.
    ds_context.experimental_set_strategy(None)
    _assert_in_default_state(self)

  def testSetStrategyInScope(self):
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    in_scope_msg = "Must not be called inside a `tf.distribute.Strategy` scope"
    with strategy.scope():
      # Every kind of argument is rejected while inside a scope: a new
      # strategy, the active strategy, and None.
      for make_arg in (_TestStrategy, lambda: strategy, lambda: None):
        with self.assertRaisesRegex(RuntimeError, in_scope_msg):
          ds_context.experimental_set_strategy(make_arg())
    _assert_in_default_state(self)

  def testSameScopeNesting(self):
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    outer = strategy.scope()
    with outer:
      self.assertIs(strategy, ds_context.get_strategy())
      inner = strategy.scope()
      with inner:
        self.assertIs(strategy, ds_context.get_strategy())
        # Re-entering an already active scope of the same strategy is fine.
        with outer:
          self.assertIs(strategy, ds_context.get_strategy())
        self.assertIs(strategy, ds_context.get_strategy())
      self.assertIs(strategy, ds_context.get_strategy())
      # Entering the scope of a different strategy is not.
      other_strategy = _TestStrategy()
      other_scope = other_strategy.scope()
      with self.assertRaisesRegex(
          RuntimeError, "Mixing different tf.distribute.Strategy objects"):
        with other_scope:
          pass
    _assert_in_default_state(self)
    # A scope object can be entered again after it has been exited.
    with inner:
      self.assertIs(strategy, ds_context.get_strategy())
    _assert_in_default_state(self)

  @_run_in_and_out_of_scope
  def testMakeInputFnIterator(self, dist):
    iterator = dist.make_input_fn_iterator(_test_input_fn)
    self.assertIsNotNone(iterator)

  @_run_in_and_out_of_scope
  def testReduce(self, dist):
    value = constant_op.constant(1.)
    reduced = dist.reduce(reduce_util.ReduceOp.MEAN, value, axis=None)
    self.assertEqual(self.evaluate(value), self.evaluate(reduced))

  def testReductions_acceptStringOps(self):
    strategy = _TestStrategy()

    def assert_unchanged(original, reduced):
      self.assertEqual(self.evaluate(original), self.evaluate(reduced))

    # Reduce ops may be spelled as lower- or upper-case strings.
    for op in ("mean", "MEAN", "sum", "SUM"):
      x = constant_op.constant(1.)
      y = constant_op.constant(1.)
      assert_unchanged(x, strategy.reduce(op, x, axis=None))
      assert_unchanged(x, strategy.extended.reduce_to(op, x, "/CPU:0"))
      x_r, y_r = strategy.extended.batch_reduce_to(
          op, ((x, "/CPU:0"), (y, "/CPU:0")))
      assert_unchanged(x, x_r)
      assert_unchanged(y, y_r)

  @_run_in_and_out_of_scope
  def testReduceMeanAxis(self, dist):
    mean = reduce_util.ReduceOp.MEAN
    x = constant_op.constant([[1., 2.], [3., 4.]])
    # axis=None leaves a single replica's value untouched.
    no_axis = dist.reduce(mean, x, axis=None)
    self.assertAllEqual(self.evaluate(x), self.evaluate(no_axis))
    along_rows = dist.reduce(mean, x, axis=0)
    self.assertAllEqual([2., 3.], self.evaluate(along_rows))
    everything = dist.reduce(mean, x, axis=(0, 1))
    self.assertEqual(2.5, self.evaluate(everything))

  @_run_in_and_out_of_scope
  def testReduceSumAxis(self, dist):
    total = reduce_util.ReduceOp.SUM
    x = constant_op.constant([[1., 2.], [3., 4.]])
    # axis=None leaves a single replica's value untouched.
    no_axis = dist.reduce(total, x, axis=None)
    self.assertAllEqual(self.evaluate(x), self.evaluate(no_axis))
    along_rows = dist.reduce(total, x, axis=0)
    self.assertAllEqual([4., 6.], self.evaluate(along_rows))
    everything = dist.reduce(total, x, axis=(0, 1))
    self.assertEqual(10., self.evaluate(everything))

  @_run_in_and_out_of_scope
  def testExperimentalRunStepsOnIterator(self, dist):
    seen = []
    ones = dataset_ops.Dataset.from_tensors(1.).repeat()
    dist.extended.experimental_run_steps_on_iterator(
        lambda _, step_inputs: seen.append(self.evaluate(step_inputs)),
        dataset_ops.make_one_shot_iterator(ones))
    self.assertEqual(seen, [1.])

  @_run_in_and_out_of_scope
  def testReduceTo(self, dist):
    value = constant_op.constant(1.)
    reduced = dist.extended.reduce_to(
        reduce_util.ReduceOp.MEAN, value, "/CPU:0")
    self.assertEqual(self.evaluate(value), self.evaluate(reduced))

  @_run_in_and_out_of_scope
  def testBatchReduceTo(self, dist):
    x = constant_op.constant(1.)
    y = constant_op.constant(1.)
    pairs = ((x, "/CPU:0"), (y, "/CPU:0"))
    x_r, y_r = dist.extended.batch_reduce_to(reduce_util.ReduceOp.MEAN, pairs)
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))
    self.assertEqual(self.evaluate(y), self.evaluate(y_r))

  @_run_in_and_out_of_scope
  def testUpdate(self, dist):
    with dist.scope():
      var = variables.Variable(1.)
    value = constant_op.constant(2.)

    def assign_fn(vv, tt):
      # `update` passes the variable first, then the extra args unchanged.
      self.assertIs(vv, var)
      self.assertIs(tt, value)

    dist.extended.update(var, assign_fn, (value,))

  @_run_in_and_out_of_scope
  def testUpdateAutoGraph(self, dist):
    with dist.scope():
      var = variables.Variable(1.)
    value = constant_op.constant(2.)

    def assign_fn(unused_vv, unused_tt):
      self.assertTrue(converter_testing.is_inside_generated_code())

    @def_function.function  # AutoGraph is default-on only within tf.function
    def traced_fn():
      dist.extended.update(var, assign_fn, (value,))

    traced_fn()

  @_run_in_and_out_of_scope
  def testUpdateNonSlot(self, dist):
    colocate_with = constant_op.constant(2.)
    calls = []
    dist.extended.update_non_slot(colocate_with, lambda: calls.append(1))
    # The update function runs exactly once.
    self.assertEqual([1], calls)

  @_run_in_and_out_of_scope
  def testUpdateNonSlotAutoGraph(self, dist):
    colocate_with = constant_op.constant(2.)

    def update_fn():
      self.assertTrue(converter_testing.is_inside_generated_code())

    @def_function.function  # AutoGraph is default-on only within tf.function
    def traced_fn():
      dist.extended.update_non_slot(colocate_with, update_fn)

    traced_fn()

  def testClusterResolverDefaultNotImplemented(self):
    strategy = _TestStrategy()
    # No cluster resolver is configured by default.
    self.assertIsNone(strategy.cluster_resolver)
    cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    resolver = SimpleClusterResolver(cluster_spec)
    # The public property reflects the resolver stored on `extended`.
    strategy.extended._cluster_resolver = resolver
    self.assertIs(strategy.cluster_resolver, resolver)
441
442
class _TestStrategy2(distribute_lib.Strategy):
  """Like `_TestStrategy`, but variable creation is left untouched.

  Its `_TestExtended2` forwards variable creation to the next creator instead
  of returning a `_get_test_variable` dict.
  """

  def __init__(self):
    extended = _TestExtended2(self)
    super().__init__(extended)
449
450
class _TestExtended2(_TestExtended):
  """`_TestExtended` variant that delegates variable creation unchanged."""

  def _create_variable(self, next_creator, **creator_kwargs):
    # Pass straight through to the next creator in the chain.
    created = next_creator(**creator_kwargs)
    return created
455
456
class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
  """Tests for the implicit default strategy active outside any scope."""

  def testMergeCall(self):
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      # Inside merge_call we are in the default strategy's cross-replica
      # context, yet has_strategy() still reports False.
      self.assertIs(ds_context._get_default_strategy(), dist)
      self.assertIs(None, ds_context.get_replica_context())
      self.assertIs(dist, ds_context.get_cross_replica_context())
      self.assertTrue(ds_context.in_cross_replica_context())
      self.assertIs(dist, ds_context.get_strategy())
      self.assertFalse(ds_context.has_strategy())
      return "foo_" + s

    replica_ctx = ds_context.get_replica_context()
    self.assertIs(ds_context._get_default_replica_context(), replica_ctx)
    self.assertEqual("foo_bar", replica_ctx.merge_call(merge_fn, args=("bar",)))
    _assert_in_default_state(self)

  def testMergeCallAutoGraph(self):
    _assert_in_default_state(self)

    def merge_fn(_, s):
      self.assertTrue(converter_testing.is_inside_generated_code())
      return s

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      replica_ctx = ds_context.get_replica_context()
      replica_ctx.merge_call(merge_fn, args=("bar",))

    test_fn()

  def testScopeMostlyNoOp(self):
    _assert_in_default_state(self)

    test_strategy = _TestStrategy2()
    with test_strategy.scope():
      variable_scope.variable(1.0, name="before")

    default_strategy = ds_context._get_default_strategy()
    scope = default_strategy.scope()
    with scope:
      # Entering the default strategy's scope does not leave the default state.
      _assert_in_default_state(self)

      with test_strategy.scope():
        with self.assertRaisesRegex(
            RuntimeError, "Mixing different tf.distribute.Strategy objects"):
          variable_scope.variable(1.0, name="error")

      # The default scope may be nested inside itself.
      with scope:
        _assert_in_default_state(self)

        with test_strategy.scope():
          with self.assertRaisesRegex(
              RuntimeError, "Mixing different tf.distribute.Strategy objects"):
            variable_scope.variable(1.0, name="also_error")

      _assert_in_default_state(self)

    _assert_in_default_state(self)
    # Once the default scope is exited, `test_strategy` works normally again.
    with test_strategy.scope():
      variable_scope.variable(1.0, name="after")

  def testExperimentalRunV2(self):
    default_strategy = ds_context._get_default_strategy()
    dataset = dataset_ops.Dataset.range(10).batch(2)
    iterator = default_strategy.extended._make_dataset_iterator(dataset)
    next_val = iterator.get_next()

    def train_step(input_data):
      return input_data

    for _ in range(2):
      default_strategy.run(train_step, args=(next_val,))

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testDistributedDatasets(self):
    default_strategy = ds_context._get_default_strategy()
    if context.executing_eagerly():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
      dist_dataset = default_strategy.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      next_val = next(iter(dist_dataset))
    else:
      dataset_fn = lambda _: dataset_ops.DatasetV1.range(10).batch(2)
      dist_dataset = default_strategy.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      iterator = dist_dataset.make_initializable_iterator()
      self.evaluate(iterator.initializer)
      next_val = iterator.get_next()
    self.assertAllEqual([0, 1], self.evaluate(next_val))

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testDistributedDatasetsFromFunction(self):
    default_strategy = ds_context._get_default_strategy()
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
    dist_dataset = default_strategy.distribute_datasets_from_function(
        dataset_fn)
    if context.executing_eagerly():
      next_val = next(iter(dist_dataset))
    else:
      # In graph mode the iterator must be initialized before it is read.
      iterator = dataset_ops.make_initializable_iterator(dist_dataset)
      self.evaluate(iterator.initializer)
      next_val = iterator.get_next()
    # Both modes must yield the same first batch.
    self.assertAllEqual([0, 1], self.evaluate(next_val))

  @combinations.generate(combinations.combine(tf_api_version=1))
  def testV1(self):
    self.assertIsInstance(ds_context.get_strategy(), distribute_lib.StrategyV1)

  @combinations.generate(combinations.combine(tf_api_version=2))
  def testV2(self):
    self.assertIsInstance(ds_context.get_strategy(), distribute_lib.Strategy)
574
575
class InputContextTest(test.TestCase):
  """Tests for `distribute_lib.InputContext`."""

  def testProperties(self):
    ctx = distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
    self.assertEqual(6, ctx.num_replicas_in_sync)
    self.assertEqual(1, ctx.input_pipeline_id)
    self.assertEqual(2, ctx.num_input_pipelines)

  def testPerReplicaBatchSize(self):
    ctx = distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
    # A global batch of 12 over 6 replicas gives 2 examples per replica.
    self.assertEqual(2, ctx.get_per_replica_batch_size(12))
    # 13 does not split evenly over 6 replicas and is expected to raise.
    with self.assertRaises(ValueError):
      ctx.get_per_replica_batch_size(13)

  def testStr(self):
    cases = (
        (1, 0, "tf.distribute.InputContext(input pipeline id 0, total: 1)"),
        (3, 1, "tf.distribute.InputContext(input pipeline id 1, total: 3)"),
    )
    for num_pipelines, pipeline_id, expected in cases:
      ctx = distribute_lib.InputContext(
          num_input_pipelines=num_pipelines,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=42)
      self.assertEqual(expected, str(ctx))
603
604
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner.
  test.main()
607