# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Distribute Coordinator."""

import contextlib
import copy
import json
import os
import sys
import threading
import time

import six

# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_manager


CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR

STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER

NUM_WORKERS = 3
NUM_PS = 2

original_sys_exit = sys.exit


def _bytes_to_str(maybe_bytes):
  if isinstance(maybe_bytes, six.string_types):
    return maybe_bytes
  else:
    return str(maybe_bytes, "utf-8")


def _strip_protocol(target):
  # cluster_spec expects "host:port" strings.
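  # e.g. a target like "grpc://host:port" becomes "host:port".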
  if "//" in target:
    return target.split("//")[1]
  else:
    return target


class MockExtended(object):

  def __init__(self,
               between_graph=False,
               should_init=None,
               should_checkpoint=None,
               should_save_summary=None):
    self.experimental_between_graph = between_graph
    self.experimental_should_init = should_init
    self.should_checkpoint = should_checkpoint
    self.should_save_summary = should_save_summary


class MockStrategy(object):

  def __init__(self,
               between_graph=False,
               should_init=None,
               should_checkpoint=None,
               should_save_summary=None):
    self.extended = MockExtended(between_graph, should_init, should_checkpoint,
                                 should_save_summary)

  def configure(self,
                session_config=None,
                cluster_spec=None,
                task_type=None,
                task_id=None):
    if self.extended.experimental_should_init is None:
      if task_id == 0:
        self.extended.experimental_should_init = True
      else:
        self.extended.experimental_should_init = False
    if self.extended.should_checkpoint is None:
      if task_id == 0:
        self.extended.should_checkpoint = True
      else:
        self.extended.should_checkpoint = False
    if self.extended.should_save_summary is None:
      if task_id == 0:
        self.extended.should_save_summary = True
      else:
        self.extended.should_save_summary = False

    if session_config:
      if (cluster_spec and task_type and task_id is not None and
          self.extended.experimental_between_graph):
        session_config.intra_op_parallelism_threads += 1
        if task_type in ["chief", "worker"]:
          session_config.device_filters.extend(
              ["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
      else:
        session_config.inter_op_parallelism_threads += 1
        session_config.device_filters.append("/job:somejob")


class MockServer(object):

  def __init__(self):
    self._joined = False
    self._started = False

  def start(self):
    self._started = True

  def join(self):
    assert not self._joined
    self._joined = True

  @property
  def joined(self):
    return self._joined

  @property
  def started(self):
    return self._started


class DistributeCoordinatorTestBase(test.TestCase):

  @classmethod
  def setUpClass(cls):
    # We have to create a global in-process cluster because once an in-process
    # tensorflow server is created, there is no way to terminate it. Please see
    # multi_worker_test_base.py for more details.
    # TODO(yuefengz): use the utility from multi_worker_test_base.
    cls._workers, cls._ps = test_util.create_local_cluster(
        NUM_WORKERS, num_ps=NUM_PS)
    cls._cluster_spec = {
        WORKER: [
            _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
        ],
        PS: [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps]
    }

  def setUp(self):
    self._result_correct = 0
    self._lock = threading.Lock()
    self._worker_context = {}
    self._strategy_property = {}
    self._std_servers = {}
    self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
    self._coord = coordinator.Coordinator()

  @contextlib.contextmanager
  def _test_session(self, target):
    config = config_pb2.ConfigProto(allow_soft_placement=True)
    config.graph_options.optimizer_options.opt_level = -1
    with session.Session(graph=None, config=config, target=target) as sess:
      yield sess

  # TODO(yuefengz): use the utility from multi_worker_test_base.
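  # Builds a cluster spec dict keyed by task type, assigning each task a
  # locally picked unused port.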
  def _create_cluster_spec(self,
                           has_chief=False,
                           num_workers=1,
                           num_ps=0,
                           has_eval=False):
    cluster_spec = {}
    if has_chief:
      cluster_spec[CHIEF] = ["localhost:%s" % test_util.pick_unused_port()]
    if num_workers:
      cluster_spec[WORKER] = [
          "localhost:%s" % test_util.pick_unused_port()
          for _ in range(num_workers)
      ]
    if num_ps:
      cluster_spec[PS] = [
          "localhost:%s" % test_util.pick_unused_port() for _ in range(num_ps)
      ]
    if has_eval:
      cluster_spec[EVALUATOR] = ["localhost:%s" % test_util.pick_unused_port()]
    return cluster_spec

  def _in_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      xs = []
      expected = 0.0
      for i in range(context.num_workers):
        with ops.device("/job:worker/task:%d" % i):
          x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
          x_add = x.assign_add(float(i))
          xs.append(x_add)
          expected += i + 10.0

      with ops.device("/job:worker/task:0"):
        result = math_ops.add_n(xs)

      self.evaluate(variables.global_variables_initializer())
      result_value = sess.run(result)
      self.assertEqual(result_value, expected)
      if result_value == expected:
        self._result_correct += 1

  def _wrapped_worker_fn(self, worker_fn):
    def wrapped(*args, **kwargs):
      with self._coord.stop_on_exception():
        return worker_fn(*args, **kwargs)
    return wrapped

  def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
    t = threading.Thread(
        target=distribute_coordinator.run_distribute_coordinator,
        args=(self._wrapped_worker_fn(worker_fn), strategy),
        kwargs=kwargs)
    t.start()
    return t

  def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
                                           cluster_spec, **kwargs):
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_coordinator_in_thread(
            worker_fn,
            strategy,
            cluster_spec=cluster_spec,
            task_type=task_type,
            task_id=task_id,
            **kwargs)
        threads[task_type].append(t)
    return threads

  def _join_threads(self, threads):
    try:
      self._coord.join(threads)
    except errors.UnknownError as e:
      if "Could not start gRPC server" in e.message:
        self.skipTest("Cannot start std servers.")
      else:
        raise

  def _between_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      with ops.device("/job:ps/task:0"):
        # TODO(yuefengz): investigate why not using resource variables makes
        # the test flaky.
        x = variable_scope.get_variable(
            "x", initializer=10.0, use_resource=True)
      with ops.device("/job:ps/task:1"):
        y = variable_scope.get_variable(
            "y", initializer=20.0, use_resource=True)

      x_add = x.assign_add(2.0)
      y_sub = y.assign_sub(2.0)
      train_op = control_flow_ops.group([x_add, y_sub])

      if context.is_chief:
        self.evaluate(variables.global_variables_initializer())

      # Synchronize workers after initialization.
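      # Either rely on the coordinator-provided barrier or poll until no
      # variables are reported as uninitialized.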
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        while True:
          uninit_vars = sess.run(variables.report_uninitialized_variables())
          # pylint: disable=g-explicit-length-test
          if len(uninit_vars) == 0:
            break

      sess.run(train_op)

      # Synchronize workers after one step to make sure they all have finished
      # training.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        self._barrier.wait()

      x_val, y_val = sess.run([x, y])

      self.assertEqual(x_val, 16.0)
      self.assertEqual(y_val, 14.0)
      if x_val == 16.0 and y_val == 14.0:
        with self._lock:
          self._result_correct += 1

  def _between_graph_with_monitored_session(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with ops.device("/job:ps/task:0"):
      # TODO(yuefengz): investigate why not using resource variables makes the
      # test flaky.
      x = variable_scope.get_variable("xx", initializer=10.0, use_resource=True)
    with ops.device("/job:ps/task:1"):
      y = variable_scope.get_variable("yy", initializer=20.0, use_resource=True)

    x_add = x.assign_add(2.0)
    y_sub = y.assign_sub(2.0)
    train_op = control_flow_ops.group([x_add, y_sub])

    # The monitored session will run init or ready ops.
    with monitored_session.MonitoredSession() as sess:
      sess.run(train_op)

      # Synchronize workers after one step to make sure they all have finished
      # training.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        self._barrier.wait()

      x_val, y_val = sess.run([x, y])

      self.assertEqual(x_val, 16.0)
      self.assertEqual(y_val, 14.0)
      if x_val == 16.0 and y_val == 14.0:
        with self._lock:
          self._result_correct += 1

  def _dump_worker_context(self, strategy):
    """Dumps the properties of each worker context.

    It dumps the context properties to a dict mapping from task_type to a list
    of tuples of master_target, num_workers, is_chief and distributed_mode,
    where the list is indexed by the task_id.

    Args:
      strategy: a `DistributionStrategy` object.
    """
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    task_type = str(context.task_type)
    task_id = context.task_id or 0
    with self._lock:
      if task_type not in self._worker_context:
        self._worker_context[task_type] = []
      while len(self._worker_context[task_type]) <= task_id:
        self._worker_context[task_type].append(None)
      self._worker_context[task_type][task_id] = (context.master_target,
                                                  context.num_workers,
                                                  context.is_chief,
                                                  context.distributed_mode)

  def _dump_strategy_property(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)

    self.assertEqual(context._strategy.extended.experimental_should_init,
                     strategy.extended.experimental_should_init)
    self.assertEqual(context.should_checkpoint,
                     strategy.extended.should_checkpoint)
    self.assertEqual(context.should_save_summary,
                     strategy.extended.should_save_summary)

    task_type = str(context.task_type)
    task_id = context.task_id or 0
    with self._lock:
      if task_type not in self._strategy_property:
        self._strategy_property[task_type] = []
      while len(self._strategy_property[task_type]) <= task_id:
        self._strategy_property[task_type].append(None)
      self._strategy_property[task_type][task_id] = (
          context._strategy.extended.experimental_should_init,
          context.should_checkpoint,
          context.should_save_summary)

  def _run_mock_std_server(self,
                           session_config=None,
                           cluster_spec=None,
                           task_type=None,
                           task_id=None,
                           rpc_layer=None,
                           environment=None):
    task_type = str(task_type)
    task_id = task_id or 0
    with self._lock:
      if task_type not in self._std_servers:
        self._std_servers[task_type] = []
      while len(self._std_servers[task_type]) <= task_id:
        self._std_servers[task_type].append(None)

      server = MockServer()
      self._std_servers[task_type][task_id] = server
    return server


class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):

  def testInGraphStandaloneMode(self):
    """Test it runs in-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._in_graph_worker_fn,
        MockStrategy(between_graph=False),
        cluster_spec=self._cluster_spec)
    self.assertEqual(self._result_correct, 1)

  def testBetweenGraph(self):
    """Test it runs between-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_worker_fn,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

  @test_util.run_v1_only("MonitoredSession removed from v2")
  def testBetweenGraphWithMonitoredSession(self):
    """Test monitored session in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_with_monitored_session,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

  def testBetweenGraphContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue(WORKER in self._worker_context)
    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(
        self._worker_context[WORKER][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
    self.assertEqual(
        self._worker_context[WORKER][1],
        (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
    self.assertEqual(
        self._worker_context[WORKER][2],
        (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))

  def testBetweenGraphStrategyProperties(self):
    # Dumps properties of the strategy objects.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_strategy_property,
        MockStrategy(between_graph=True, should_init=True),
        cluster_spec=self._cluster_spec)

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._strategy_property), 1)
    self.assertTrue(WORKER in self._strategy_property)
    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)

    # Check whether each task has the right properties of should_init,
    # should_checkpoint and should_save_summary.
    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))

  def testInGraphContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=False),
        cluster_spec=self._cluster_spec)

    # There is only a "None" task in the dumped task context.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue("None" in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(
        self._worker_context["None"][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))

  def testLocalContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=False),
        cluster_spec=None)

    # There is only a "None" task.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue("None" in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))

  def testBetweenGraphContextWithChief(self):
    # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
    cluster_spec = copy.deepcopy(self._cluster_spec)
    cluster_spec[CHIEF] = ["fake_chief"]

    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=True),
        cluster_spec=cluster_spec,
        rpc_layer="grpc")

    # There is one chief and there are three workers.
    self.assertEqual(len(self._worker_context), 2)
    self.assertTrue(CHIEF in self._worker_context)
    self.assertTrue(WORKER in self._worker_context)
    self.assertEqual(len(self._worker_context[CHIEF]), 1)
    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(self._worker_context[CHIEF][0],
                     ("grpc://fake_chief", 4, True, True))
    self.assertEqual(
        self._worker_context[WORKER][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS + 1, False, True))
    self.assertEqual(
        self._worker_context[WORKER][1],
        (_bytes_to_str(self._workers[1].target), NUM_WORKERS + 1, False, True))
    self.assertEqual(
        self._worker_context[WORKER][2],
        (_bytes_to_str(self._workers[2].target), NUM_WORKERS + 1, False, True))

  def testInGraphContextWithEval(self):
    # Adds an EVALUATOR job.
    cluster_spec = copy.deepcopy(self._cluster_spec)
    cluster_spec[EVALUATOR] = ["fake_evaluator"]

    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=False),
        cluster_spec=cluster_spec,
        rpc_layer=None)

    # There is one "None" task and one EVALUATOR task.
    self.assertEqual(len(self._worker_context), 2)
    self.assertTrue("None" in self._worker_context)
    self.assertTrue(EVALUATOR in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)
    self.assertEqual(len(self._worker_context[EVALUATOR]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
        _bytes_to_str(self._workers[0].target)), 3, True, True))
    self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))


class DistributeCoordinatorTestIndependentWorkerMode(
    DistributeCoordinatorTestBase):

  def testInGraph(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    threads = self._run_multiple_coordinator_in_threads(
        self._in_graph_worker_fn,
        MockStrategy(between_graph=False),
        cluster_spec,
        mode=INDEPENDENT_WORKER)
    self._join_threads([threads[WORKER][0]])
    self.assertEqual(self._result_correct, 1)

  def testBetweenGraph(self):
    cluster_spec = self._create_cluster_spec(
        num_workers=NUM_WORKERS, num_ps=NUM_PS)
    threads = self._run_multiple_coordinator_in_threads(
        self._between_graph_worker_fn,
        MockStrategy(between_graph=True),
        cluster_spec,
        mode=INDEPENDENT_WORKER)
    self._join_threads(threads[WORKER])

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

  @test_util.run_v1_only("MonitoredSession removed from v2")
  def testBetweenGraphWithMonitoredSession(self):
    cluster_spec = self._create_cluster_spec(
        num_workers=NUM_WORKERS, num_ps=NUM_PS)
    threads = self._run_multiple_coordinator_in_threads(
        self._between_graph_with_monitored_session,
        MockStrategy(between_graph=True),
        cluster_spec,
        mode=INDEPENDENT_WORKER)
    self._join_threads(threads[WORKER])

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

  def testBetweenGraphContext(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    # Dumps the task contexts and std server arguments.
    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
          MockStrategy(between_graph=True),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
          rpc_layer=None)
      self._join_threads(threads[WORKER])

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue(WORKER in self._worker_context)
    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(
        self._worker_context[WORKER][0],
        (_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True))
    self.assertEqual(
        self._worker_context[WORKER][1],
        (_bytes_to_str(cluster_spec[WORKER][1]), NUM_WORKERS, False, True))
    self.assertEqual(
        self._worker_context[WORKER][2],
        (_bytes_to_str(cluster_spec[WORKER][2]), NUM_WORKERS, False, True))

    # Make sure each worker runs a std server.
    self.assertEqual(len(self._std_servers), 1)
    self.assertTrue(WORKER in self._std_servers)
    self.assertEqual(len(self._std_servers[WORKER]), 3)
    self.assertFalse(self._std_servers[WORKER][0].joined)
    self.assertFalse(self._std_servers[WORKER][1].joined)
    self.assertFalse(self._std_servers[WORKER][2].joined)

  def testBetweenGraphStrategyProperties(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    # Dumps properties of the strategy objects.
    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_strategy_property,
          MockStrategy(between_graph=True, should_init=True),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
          rpc_layer=None)
      self._join_threads(threads[WORKER])

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._strategy_property), 1)
    self.assertTrue(WORKER in self._strategy_property)
    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)

    # Check whether each task has the right properties of should_init,
    # should_checkpoint and should_save_summary.
    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))

  def testInGraphContext(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    # Dumps the task contexts and std server arguments.
    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
          MockStrategy(between_graph=False),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
          rpc_layer=None)
      self._join_threads(threads[WORKER])

    # There is only a "None" task in the dumped task context.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue("None" in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(
        self._worker_context["None"][0],
        (_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True))

    # Make sure each worker runs a std server.
    self.assertEqual(len(self._std_servers), 1)
    self.assertTrue(WORKER in self._std_servers)
    self.assertEqual(len(self._std_servers[WORKER]), 3)
    self.assertFalse(self._std_servers[WORKER][0].joined)
    self.assertTrue(self._std_servers[WORKER][1].joined)
    self.assertTrue(self._std_servers[WORKER][2].joined)

  def testInGraphContextWithEval(self):
    # Adds an EVALUATOR job.
    cluster_spec = self._create_cluster_spec(
        num_workers=NUM_WORKERS, has_eval=True)

    # Dumps the task contexts and std server arguments.
    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
          MockStrategy(between_graph=False),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
          rpc_layer=None)
      self._join_threads(threads[WORKER])
      self._join_threads([threads[EVALUATOR][0]])

    # There is one "None" task and one EVALUATOR task.
    self.assertEqual(len(self._worker_context), 2)
    self.assertTrue("None" in self._worker_context)
    self.assertTrue(EVALUATOR in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)
    self.assertEqual(len(self._worker_context[EVALUATOR]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(self._worker_context["None"][0],
                     (_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
    self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))

    # Make sure each worker runs a std server.
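    # In in-graph replication only worker 0's std server stays unjoined; the
    # other tasks join their servers.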
    self.assertEqual(len(self._std_servers), 1)
    self.assertTrue(WORKER in self._std_servers)
    self.assertEqual(len(self._std_servers[WORKER]), 3)
    self.assertFalse(self._std_servers[WORKER][0].joined)
    self.assertTrue(self._std_servers[WORKER][1].joined)
    self.assertTrue(self._std_servers[WORKER][2].joined)

  def testRunStdServerInGoogleEnvironment(self):
    cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
    tf_config = {"cluster": cluster_spec, "environment": "google"}

    joined = [False]

    def _fake_sleep(_):
      joined[0] = True
      original_sys_exit(0)

    def _thread_fn(cluster_spec):
      distribute_coordinator.run_distribute_coordinator(
          None,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="ps",
          task_id=0)

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            time, "sleep", _fake_sleep):
      t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
      t.start()
      t.join()
    self.assertTrue(joined[0])

  def testRpcLayerEnvironmentVariable(self):
    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
    tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}

    rpc_layer_from_coordinator = [None]

    def _run_mock_server(cluster_spec=None,
                         task_type=None,
                         task_id=None,
                         session_config=None,
                         rpc_layer=None,
                         environment=None):
      del cluster_spec, task_type, task_id, session_config, environment
      rpc_layer_from_coordinator[0] = rpc_layer
      return MockServer()

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            distribute_coordinator, "_run_std_server", _run_mock_server):
      distribute_coordinator.run_distribute_coordinator(
          None,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="ps",
          task_id=0)
    self.assertEqual(rpc_layer_from_coordinator[0], "cake")


class StrategyConfigureTest(test.TestCase):

  def setUp(self):
    self._device_filters = []
    self._intra_op_parallelism_threads = None
    self._inter_op_parallelism_threads = None
    super(StrategyConfigureTest, self).setUp()

  def _dump_device_filters(self, *args, **kwargs):
    session_config = kwargs.get("session_config", None)
    self._device_filters.extend(session_config.device_filters)
    self._intra_op_parallelism_threads = (
        session_config.intra_op_parallelism_threads)
    self._inter_op_parallelism_threads = (
        session_config.inter_op_parallelism_threads)
    return MockServer()

  def _worker_fn(self, strategy):
    worker_context = distribute_coordinator_context.get_current_worker_context()
    session_config = worker_context._session_config
    self._device_filters.extend(session_config.device_filters)
    self._intra_op_parallelism_threads = (
        session_config.intra_op_parallelism_threads)
    self._inter_op_parallelism_threads = (
        session_config.inter_op_parallelism_threads)
    return MockServer()

  def test_session_config_in_std_server(self):
    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
    tf_config = {"cluster": cluster_spec}

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            distribute_coordinator, "_run_std_server",
            self._dump_device_filters):
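      # The mocked std server records the session_config the coordinator
      # passes to it.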
      distribute_coordinator.run_distribute_coordinator(
          lambda _: None,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="worker",
          task_id=0)
    self.assertEqual(self._intra_op_parallelism_threads, 1)
    self.assertEqual(self._inter_op_parallelism_threads, 0)

  def test_session_config_in_session_creator(self):
    cluster_spec = {"worker": ["localhost:0"]}
    tf_config = {"cluster": cluster_spec}

    # Reset the saved Server state.
    distribute_coordinator._thread_local = threading.local()  # pylint: disable=protected-access

    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(tf_config)}):
      distribute_coordinator.run_distribute_coordinator(
          self._worker_fn,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="worker",
          task_id=0)
    self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
    self.assertEqual(self._intra_op_parallelism_threads, 2)
    self.assertEqual(self._inter_op_parallelism_threads, 0)

  def test_eval_strategy_configure(self):
    cluster_spec = {"evaluator": ["localhost:0"]}
    tf_config = {"cluster": cluster_spec}

    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(tf_config)}):
      distribute_coordinator.run_distribute_coordinator(
          lambda _: None,
          MockStrategy(between_graph=False),
          eval_fn=self._worker_fn,
          eval_strategy=MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="evaluator",
          task_id=0)
    self.assertEqual(self._device_filters, ["/job:somejob"])
    self.assertEqual(self._intra_op_parallelism_threads, 0)
    self.assertEqual(self._inter_op_parallelism_threads, 2)


class RunStandardTensorflowServerTest(test.TestCase):

  def test_std_server_arguments(self):
    cs = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
    tf_config = {"cluster": cs, "task": {"type": "ps", "id": 0}}

    def _mock_run_std_server(cluster_spec=None,
                             task_type=None,
                             task_id=None,
                             session_config=None,
                             rpc_layer=None):
      self.assertEqual(cluster_spec.as_dict(), cs)
      self.assertEqual(task_type, "ps")
      self.assertEqual(task_id, 0)
      self.assertEqual(session_config.experimental.collective_group_leader,
                       "/job:worker/replica:0/task:0")
      self.assertEqual(session_config.intra_op_parallelism_threads, 1)
      self.assertEqual(rpc_layer, "grpc")

      return MockServer()

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            distribute_coordinator, "_run_std_server", _mock_run_std_server):
      session_config = config_pb2.ConfigProto()
      session_config.intra_op_parallelism_threads = 1
      mock_server = distribute_coordinator.run_standard_tensorflow_server(
          session_config)
      self.assertTrue(mock_server.started)


if __name__ == "__main__":
  # TODO(yuefengz): find a smart way to terminate std server threads.
  with test.mock.patch.object(sys, "exit", os._exit):
    # Reduce `recovery_wait_secs` from 30 seconds so the test completes
    # quickly.
    orig_init = session_manager.SessionManager.__init__

    def new_init(*args, **kwargs):
      kwargs.pop("recovery_wait_secs", None)
      kwargs["recovery_wait_secs"] = 0.5
      orig_init(*args, **kwargs)

    session_manager.SessionManager.__init__ = new_init

    test.main()