# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test DistributionStrategy, ReplicaContext, and supporting APIs."""

import functools

from absl.testing import parameterized

from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest


class _TestReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext whose `merge_call` echoes the `test_arg` keyword argument."""

  def merge_call(self, fn, *args, **kwargs):
    # `fn` and positional args are ignored: tests pass the value they expect
    # back as `test_arg` and check that it round-trips.
    return kwargs["test_arg"]


def _get_test_variable(name, synchronization, aggregation):
  """Returns a dict standing in for a variable created by `_TestExtended`.

  Args:
    name: the `name` argument given to variable creation.
    synchronization: the `synchronization` argument given to variable creation.
    aggregation: the `aggregation` argument given to variable creation.

  Returns:
    A dict with keys "name", "synchronization" and "aggregation".
  """
  return {
      "name": name,
      "synchronization": synchronization,
      "aggregation": aggregation
  }


def _test_input_fn(input_context):
  """Input function producing the scalar 1. repeated; the context is unused."""
  del input_context
  return dataset_ops.DatasetV2.from_tensors(1.).repeat()


class _TestStrategy(distribute_lib.Strategy):
  """Minimal `Strategy` whose behavior is provided by `_TestExtended`."""

  def __init__(self):
    super(_TestStrategy, self).__init__(_TestExtended(self))


class _TestExtended(distribute_lib.StrategyExtendedV1):
  """Test implementation of `StrategyExtendedV1` with one CPU input device.

  Replica functions run in-line under a `_TestReplicaContext` with replica id
  0, variable creation returns `_get_test_variable` dicts, and `_reduce_to`
  returns its input unchanged.
  """

  def __init__(self, distribute):
    super(_TestExtended, self).__init__(distribute)
    # A single worker ("" host) with a single CPU device.
    worker_device_pairs = [("", ["/device:CPU:0"])]
    self._input_workers = input_lib.InputWorkers(worker_device_pairs)

  def _call_for_each_replica(self, fn, args, kwargs):
    with _TestReplicaContext(
        self._container_strategy(), replica_id_in_sync_group=0):
      return fn(*args, **kwargs)

  def _create_variable(self, next_creator, **kwargs):
    # Deliberately skips `next_creator` so tests can inspect the creation
    # arguments directly instead of getting a real variable.
    return _get_test_variable(kwargs["name"], kwargs["synchronization"],
                              kwargs["aggregation"])

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [distribute_lib.InputContext()],
                                              self._container_strategy())

  def _distribute_datasets_from_function(self, dataset_fn, options):
    return dataset_fn(distribute_lib.InputContext())

  def _local_results(self, value):
    return (value,)

  def _reduce_to(self, reduce_op, value, destinations, options):
    # Every reduction is the identity for this test strategy.
    del reduce_op, destinations, options
    return value

  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    # TODO(tomhennigan) This is missing many things (e.g. ctx.run_op).
    ctx = input_lib.MultiStepContext()
    for _ in range(iterations):
      fn(ctx, iterator.get_next())
    return ctx

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    result = fn(*args, **kwargs)
    if group:
      return result
    else:
      return nest.map_structure(self._unwrap, result)

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


def _assert_in_default_state(t):
  """Asserts via test case `t` that no non-default strategy is active."""
  t.assertIs(ds_context._get_default_replica_context(),
             ds_context.get_replica_context())
  t.assertIsNone(ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  t.assertIs(ds_context._get_default_strategy(), ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())


def _run_in_and_out_of_scope(unbound_test_method):
  """Decorates a `(test_case, dist)` test to run in several strategy settings.

  The wrapped method is run with a fresh `_TestStrategy` `dist`:
    1. outside any strategy scope,
    2. inside `dist.scope()`,
    3. inside the scope of a different `_TestStrategy`, where it must raise a
       `RuntimeError` matching "Mixing different .*Strategy objects".

  Args:
    unbound_test_method: a function taking `(test_case, dist)`.

  Returns:
    A test method taking only `test_case`. It carries the decorated function's
    name and docstring via `functools.wraps`.
  """

  @functools.wraps(unbound_test_method)
  def wrapper(test_case):
    dist = _TestStrategy()
    # Running in the default (replica) scope should be supported.
    _assert_in_default_state(test_case)
    unbound_test_method(test_case, dist)
    # As well as running in the strategy scope.
    with dist.scope():
      unbound_test_method(test_case, dist)
    _assert_in_default_state(test_case)
    # When run under a different strategy the test method should fail.
    another_strategy = _TestStrategy()
    msg = "Mixing different .*Strategy objects"
    with test_case.assertRaisesRegex(RuntimeError, msg):
      with another_strategy.scope():
        unbound_test_method(test_case, dist)

  return wrapper


class TestStrategyTest(test.TestCase):
  """Tests for a custom `Strategy` built on `_TestExtended`."""

  def testCallForEachReplica(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()

    def run_fn():
      replica_context = ds_context.get_replica_context()
      self.assertIsNotNone(replica_context)
      self.assertIsNone(ds_context.get_cross_replica_context())
      self.assertFalse(ds_context.in_cross_replica_context())
      self.assertTrue(ds_context.has_strategy())
      self.assertIs(dist, ds_context.get_strategy())
      self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
      expected_value = _get_test_variable(
          "bar", variable_scope.VariableSynchronization.AUTO,
          variable_scope.VariableAggregation.NONE)
      self.assertDictEqual(expected_value,
                           variable_scope.variable(1.0, name="bar"))

    dist.extended.call_for_each_replica(run_fn)
    with dist.scope():
      dist.extended.call_for_each_replica(run_fn)
    _assert_in_default_state(self)

  def testScope(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      self.assertIsNone(ds_context.get_replica_context())
      self.assertIs(dist, ds_context.get_cross_replica_context())
      self.assertTrue(ds_context.in_cross_replica_context())
      self.assertTrue(ds_context.has_strategy())
      self.assertIs(dist, ds_context.get_strategy())
      expected_value = _get_test_variable(
          "baz", variable_scope.VariableSynchronization.AUTO,
          variable_scope.VariableAggregation.NONE)
      self.assertDictEqual(expected_value,
                           variable_scope.variable(1.0, name="baz"))
    _assert_in_default_state(self)

  def testScopeDeviceNestingError(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    # Open a device scope with dist.scope().
    dist.extended._default_device = "/device:GPU:0"
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    # Exiting the strategy scope while a device scope opened after it is still
    # active must be rejected.
    with ops.device("/device:CPU:0"):
      with self.assertRaisesRegex(RuntimeError, "Device scope nesting error"):
        scope.__exit__(None, None, None)
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testScopeVarCreatorNestingError(self):

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with variable_scope.variable_creator_scope(creator):
      with self.assertRaisesRegex(RuntimeError,
                                  "Variable creator scope nesting error"):
        scope.__exit__(None, None, None)
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testScopeVarScopeNestingError(self):
    # We create a new graph here to simplify clean-up, since the error
    # we are triggering happens in the middle of scope.__exit__() and
    # leaves us in a weird state.
    with ops.Graph().as_default():
      _assert_in_default_state(self)
      dist = _TestStrategy()
      scope = dist.scope()
      scope.__enter__()
      self.assertIs(dist, ds_context.get_strategy())
      with variable_scope.variable_scope("AA"):
        with self.assertRaisesRegex(RuntimeError,
                                    "Variable scope nesting error"):
          scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testSettingSynchronizationAndAggregation(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      expected_value = _get_test_variable(
          "baz", variable_scope.VariableSynchronization.ON_WRITE,
          variable_scope.VariableAggregation.MEAN)
      self.assertDictEqual(
          expected_value,
          variable_scope.variable(
              1.0,
              name="baz",
              synchronization=variable_scope.VariableSynchronization.ON_WRITE,
              aggregation=variable_scope.VariableAggregation.MEAN))
    _assert_in_default_state(self)

  def testSetStrategy(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    dist2 = _TestStrategy()
    ds_context.experimental_set_strategy(dist)
    self.assertIsNone(ds_context.get_replica_context())
    self.assertIs(dist, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(dist, ds_context.get_strategy())
    expected_value = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    self.assertDictEqual(expected_value,
                         variable_scope.variable(1.0, name="baz"))
    ds_context.experimental_set_strategy(dist2)
    self.assertIs(dist2, ds_context.get_strategy())
    # Setting None restores the default strategy.
    ds_context.experimental_set_strategy(None)
    _assert_in_default_state(self)

  def testSetStrategyInScope(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        ds_context.experimental_set_strategy(_TestStrategy())
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        ds_context.experimental_set_strategy(dist)
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        ds_context.experimental_set_strategy(None)
    _assert_in_default_state(self)

  def testSameScopeNesting(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope_a = dist.scope()
    with scope_a:
      self.assertIs(dist, ds_context.get_strategy())
      scope_b = dist.scope()
      with scope_b:
        self.assertIs(dist, ds_context.get_strategy())
        # Re-entering an already-active scope of the same strategy is allowed.
        with scope_a:
          self.assertIs(dist, ds_context.get_strategy())
        self.assertIs(dist, ds_context.get_strategy())
      self.assertIs(dist, ds_context.get_strategy())
      dist2 = _TestStrategy()
      scope2 = dist2.scope()
      with self.assertRaisesRegex(
          RuntimeError, "Mixing different tf.distribute.Strategy objects"):
        with scope2:
          pass
    _assert_in_default_state(self)
    with scope_b:
      self.assertIs(dist, ds_context.get_strategy())
    _assert_in_default_state(self)

  @_run_in_and_out_of_scope
  def testMakeInputFnIterator(self, dist):
    self.assertIsNotNone(dist.make_input_fn_iterator(_test_input_fn))

  @_run_in_and_out_of_scope
  def testReduce(self, dist):
    x = constant_op.constant(1.)
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))

  def testReductions_acceptStringOps(self):
    dist = _TestStrategy()
    # Reduce ops may be given as strings, in either upper or lower case.
    for op in ("mean", "MEAN", "sum", "SUM"):
      x = constant_op.constant(1.)
      y = constant_op.constant(1.)
      x_r = dist.reduce(op, x, axis=None)
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      x_r = dist.extended.reduce_to(op, x, "/CPU:0")
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      x_r, y_r = dist.extended.batch_reduce_to(op,
                                               ((x, "/CPU:0"), (y, "/CPU:0")))
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      self.assertEqual(self.evaluate(y), self.evaluate(y_r))

  @_run_in_and_out_of_scope
  def testReduceMeanAxis(self, dist):
    x = constant_op.constant([[1., 2.], [3., 4.]])
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
    self.assertAllEqual(self.evaluate(x), self.evaluate(x_r))
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=0)
    self.assertAllEqual([2., 3.], self.evaluate(x_r))
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=(0, 1))
    self.assertEqual(2.5, self.evaluate(x_r))

  @_run_in_and_out_of_scope
  def testReduceSumAxis(self, dist):
    x = constant_op.constant([[1., 2.], [3., 4.]])
    x_r = dist.reduce(reduce_util.ReduceOp.SUM, x, axis=None)
    self.assertAllEqual(self.evaluate(x), self.evaluate(x_r))
    x_r = dist.reduce(reduce_util.ReduceOp.SUM, x, axis=0)
    self.assertAllEqual([4., 6.], self.evaluate(x_r))
    x_r = dist.reduce(reduce_util.ReduceOp.SUM, x, axis=(0, 1))
    self.assertEqual(10., self.evaluate(x_r))

  @_run_in_and_out_of_scope
  def testExperimentalRunStepsOnIterator(self, dist):
    all_inputs = []
    dataset = dataset_ops.Dataset.from_tensors(1.).repeat()
    dist.extended.experimental_run_steps_on_iterator(
        lambda _, inputs: all_inputs.append(self.evaluate(inputs)),
        dataset_ops.make_one_shot_iterator(dataset))
    self.assertEqual(all_inputs, [1.])

  @_run_in_and_out_of_scope
  def testReduceTo(self, dist):
    x = constant_op.constant(1.)
    x_r = dist.extended.reduce_to(reduce_util.ReduceOp.MEAN, x, "/CPU:0")
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))

  @_run_in_and_out_of_scope
  def testBatchReduceTo(self, dist):
    x = constant_op.constant(1.)
    y = constant_op.constant(1.)
    x_r, y_r = dist.extended.batch_reduce_to(reduce_util.ReduceOp.MEAN,
                                             ((x, "/CPU:0"), (y, "/CPU:0")))
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))
    self.assertEqual(self.evaluate(y), self.evaluate(y_r))

  @_run_in_and_out_of_scope
  def testUpdate(self, dist):
    with dist.scope():
      v = variables.Variable(1.)
    t = constant_op.constant(2.)

    def assign_fn(vv, tt):
      self.assertIs(vv, v)
      self.assertIs(tt, t)
    dist.extended.update(v, assign_fn, (t,))

  @_run_in_and_out_of_scope
  def testUpdateAutoGraph(self, dist):
    with dist.scope():
      v = variables.Variable(1.)
    t = constant_op.constant(2.)

    def assign_fn(unused_vv, unused_tt):
      self.assertTrue(converter_testing.is_inside_generated_code())

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      dist.extended.update(v, assign_fn, (t,))

    test_fn()

  @_run_in_and_out_of_scope
  def testUpdateNonSlot(self, dist):
    t = constant_op.constant(2.)
    update_calls = []
    dist.extended.update_non_slot(t, lambda: update_calls.append(1))
    self.assertEqual(len(update_calls), 1)

  @_run_in_and_out_of_scope
  def testUpdateNonSlotAutoGraph(self, dist):
    t = constant_op.constant(2.)

    def update_fn():
      self.assertTrue(converter_testing.is_inside_generated_code())

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      dist.extended.update_non_slot(t, update_fn)

    test_fn()

  def testClusterResolverDefaultNotImplemented(self):
    dist = _TestStrategy()
    self.assertIsNone(dist.cluster_resolver)
    base_cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    cluster_resolver = SimpleClusterResolver(base_cluster_spec)
    # `cluster_resolver` reflects whatever the extended object holds.
    dist.extended._cluster_resolver = cluster_resolver
    self.assertIs(dist.cluster_resolver, cluster_resolver)


# _TestStrategy2 is like _TestStrategy, except it doesn't change variable
# creation.
class _TestStrategy2(distribute_lib.Strategy):

  def __init__(self):
    super(_TestStrategy2, self).__init__(_TestExtended2(self))


class _TestExtended2(_TestExtended):
  """`_TestExtended` that delegates variable creation to the next creator."""

  def _create_variable(self, next_creator, **kwargs):
    return next_creator(**kwargs)


class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
  """Tests for the default (no explicit strategy) distribution strategy."""

  def testMergeCall(self):
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      self.assertIs(ds_context._get_default_strategy(), dist)
      self.assertIsNone(ds_context.get_replica_context())
      self.assertIs(dist, ds_context.get_cross_replica_context())
      self.assertTrue(ds_context.in_cross_replica_context())
      self.assertIs(dist, ds_context.get_strategy())
      self.assertFalse(ds_context.has_strategy())
      return "foo_" + s

    replica_ctx = ds_context.get_replica_context()
    self.assertIs(ds_context._get_default_replica_context(), replica_ctx)
    self.assertEqual("foo_bar", replica_ctx.merge_call(merge_fn, args=("bar",)))
    _assert_in_default_state(self)

  def testMergeCallAutoGraph(self):
    _assert_in_default_state(self)

    def merge_fn(_, s):
      self.assertTrue(converter_testing.is_inside_generated_code())
      return s

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      replica_ctx = ds_context.get_replica_context()
      replica_ctx.merge_call(merge_fn, args=("bar",))

    test_fn()

  def testScopeMostlyNoOp(self):
    _assert_in_default_state(self)

    test_strategy = _TestStrategy2()
    with test_strategy.scope():
      variable_scope.variable(1.0, name="before")

    default_strategy = ds_context._get_default_strategy()
    scope = default_strategy.scope()
    with scope:
      _assert_in_default_state(self)

      with test_strategy.scope():
        with self.assertRaisesRegex(
            RuntimeError, "Mixing different tf.distribute.Strategy objects"):
          variable_scope.variable(1.0, name="error")

      with scope:
        _assert_in_default_state(self)

        with test_strategy.scope():
          with self.assertRaisesRegex(
              RuntimeError, "Mixing different tf.distribute.Strategy objects"):
            variable_scope.variable(1.0, name="also_error")

      _assert_in_default_state(self)

    _assert_in_default_state(self)
    with test_strategy.scope():
      variable_scope.variable(1.0, name="after")

  def testExperimentalRunV2(self):
    default_strategy = ds_context._get_default_strategy()
    dataset = dataset_ops.Dataset.range(10).batch(2)
    iterator = default_strategy.extended._make_dataset_iterator(dataset)
    next_val = iterator.get_next()

    def train_step(input_data):
      return input_data

    for _ in range(2):
      default_strategy.run(train_step, args=(next_val,))

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testDistributedDatasets(self):
    default_strategy = ds_context._get_default_strategy()
    if context.executing_eagerly():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
      dist_dataset = default_strategy.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      next_val = next(iter(dist_dataset))
    else:
      dataset_fn = lambda _: dataset_ops.DatasetV1.range(10).batch(2)
      dist_dataset = default_strategy.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      iterator = dist_dataset.make_initializable_iterator()
      self.evaluate(iterator.initializer)
      next_val = iterator.get_next()
    self.assertAllEqual([0, 1], self.evaluate(next_val))

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testDistributedDatasetsFromFunction(self):
    default_strategy = ds_context._get_default_strategy()
    if context.executing_eagerly():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
      dist_dataset_from_func = (
          default_strategy.distribute_datasets_from_function(dataset_fn))
      next_val = next(iter(dist_dataset_from_func))
      self.assertAllEqual([0, 1], self.evaluate(next_val))
    else:
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
      dist_dataset_from_func = (
          default_strategy.distribute_datasets_from_function(dataset_fn))
      # In graph mode this only checks that an iterator can be created.
      dataset_ops.make_initializable_iterator(dist_dataset_from_func)

  @combinations.generate(combinations.combine(tf_api_version=1))
  def testV1(self):
    self.assertIsInstance(ds_context.get_strategy(), distribute_lib.StrategyV1)

  @combinations.generate(combinations.combine(tf_api_version=2))
  def testV2(self):
    self.assertIsInstance(ds_context.get_strategy(), distribute_lib.Strategy)


class InputContextTest(test.TestCase):
  """Tests for `distribute_lib.InputContext`."""

  def testProperties(self):
    input_context = distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
    self.assertEqual(6, input_context.num_replicas_in_sync)
    self.assertEqual(1, input_context.input_pipeline_id)
    self.assertEqual(2, input_context.num_input_pipelines)

  def testPerReplicaBatchSize(self):
    input_context = distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
    self.assertEqual(2, input_context.get_per_replica_batch_size(12))
    # 13 is not divisible by the 6 replicas.
    with self.assertRaises(ValueError):
      input_context.get_per_replica_batch_size(13)

  def testStr(self):
    input_context = distribute_lib.InputContext(
        num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=42)
    self.assertEqual(
        "tf.distribute.InputContext(input pipeline id 0, total: 1)",
        str(input_context))
    input_context = distribute_lib.InputContext(
        num_input_pipelines=3, input_pipeline_id=1, num_replicas_in_sync=42)
    self.assertEqual(
        "tf.distribute.InputContext(input pipeline id 1, total: 3)",
        str(input_context))


if __name__ == "__main__":
  test.main()