# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Distribute Coordinator."""

import contextlib
import copy
import json
import os
import sys
import threading
import time

import six

# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_manager


CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR

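# In STANDALONE_CLIENT mode a single client drives the whole cluster and runs
# worker_fn on its behalf; in INDEPENDENT_WORKER mode every task runs the
# coordinator (and worker_fn) itself.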
STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER

NUM_WORKERS = 3
NUM_PS = 2

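# Keep a reference to the real sys.exit: __main__ below replaces sys.exit
# with os._exit, and testRunStdServerInGoogleEnvironment needs the original
# to end only its own thread.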
original_sys_exit = sys.exit


def _bytes_to_str(maybe_bytes):
  if isinstance(maybe_bytes, six.string_types):
    return maybe_bytes
  else:
    return str(maybe_bytes, "utf-8")


def _strip_protocol(target):
  # cluster_spec expects "host:port" strings.
  if "//" in target:
    return target.split("//")[1]
  else:
    return target


class MockExtended(object):
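  """Mimics the `extended` attribute of a DistributionStrategy."""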

  def __init__(self,
               between_graph=False,
               should_init=None,
               should_checkpoint=None,
               should_save_summary=None):
    self.experimental_between_graph = between_graph
    self.experimental_should_init = should_init
    self.should_checkpoint = should_checkpoint
    self.should_save_summary = should_save_summary


class MockStrategy(object):
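  """Mimics a DistributionStrategy so tests can observe configure() calls."""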

  def __init__(self,
               between_graph=False,
               should_init=None,
               should_checkpoint=None,
               should_save_summary=None):
    self.extended = MockExtended(between_graph, should_init, should_checkpoint,
                                 should_save_summary)

  def configure(self,
                session_config=None,
                cluster_spec=None,
                task_type=None,
                task_id=None):
    if self.extended.experimental_should_init is None:
      if task_id == 0:
        self.extended.experimental_should_init = True
      else:
        self.extended.experimental_should_init = False
    if self.extended.should_checkpoint is None:
      if task_id == 0:
        self.extended.should_checkpoint = True
      else:
        self.extended.should_checkpoint = False
    if self.extended.should_save_summary is None:
      if task_id == 0:
        self.extended.should_save_summary = True
      else:
        self.extended.should_save_summary = False

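    # Mutate the session config in distinguishable ways so tests can tell
    # which path called configure(): between-graph worker tasks get an extra
    # intra-op thread plus per-task device filters, while everything else
    # gets an extra inter-op thread and a "/job:somejob" filter.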
    if session_config:
      if (cluster_spec and task_type and task_id is not None and
          self.extended.experimental_between_graph):
        session_config.intra_op_parallelism_threads += 1
        if task_type in ["chief", "worker"]:
          session_config.device_filters.extend(
              ["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
      else:
        session_config.inter_op_parallelism_threads += 1
        session_config.device_filters.append("/job:somejob")


class MockServer(object):
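  """Stand-in for a std server that records start() and join() calls."""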

  def __init__(self):
    self._joined = False
    self._started = False

  def start(self):
    self._started = True

  def join(self):
    assert not self._joined
    self._joined = True

  @property
  def joined(self):
    return self._joined

  @property
  def started(self):
    return self._started


class DistributeCoordinatorTestBase(test.TestCase):

  @classmethod
  def setUpClass(cls):
    # We have to create a global in-process cluster because once an in-process
    # tensorflow server is created, there is no way to terminate it. Please see
    # multi_worker_test_base.py for more details.
    # TODO(yuefengz): use the utility from multi_worker_test_base.
    cls._workers, cls._ps = test_util.create_local_cluster(
        NUM_WORKERS, num_ps=NUM_PS)
    cls._cluster_spec = {
        WORKER: [
            _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
        ],
        PS: [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps]
    }

  def setUp(self):
    self._result_correct = 0
    self._lock = threading.Lock()
    self._worker_context = {}
    self._strategy_property = {}
    self._std_servers = {}
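    # Used by worker_fns to synchronize with each other when the worker
    # context does not provide a barrier of its own.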
    self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
    self._coord = coordinator.Coordinator()

  @contextlib.contextmanager
  def _test_session(self, target):
    config = config_pb2.ConfigProto(allow_soft_placement=True)
    config.graph_options.optimizer_options.opt_level = -1
    with session.Session(graph=None, config=config, target=target) as sess:
      yield sess

  # TODO(yuefengz): use the utility from multi_worker_test_base.
  def _create_cluster_spec(self,
                           has_chief=False,
                           num_workers=1,
                           num_ps=0,
                           has_eval=False):
    cluster_spec = {}
    if has_chief:
      cluster_spec[CHIEF] = ["localhost:%s" % test_util.pick_unused_port()]
    if num_workers:
      cluster_spec[WORKER] = [
          "localhost:%s" % test_util.pick_unused_port()
          for _ in range(num_workers)
      ]
    if num_ps:
      cluster_spec[PS] = [
          "localhost:%s" % test_util.pick_unused_port() for _ in range(num_ps)
      ]
    if has_eval:
      cluster_spec[EVALUATOR] = ["localhost:%s" % test_util.pick_unused_port()]
    return cluster_spec

  def _in_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      xs = []
      expected = 0.0
      for i in range(context.num_workers):
        with ops.device("/job:worker/task:%d" % i):
          x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
          x_add = x.assign_add(float(i))
          xs.append(x_add)
          expected += i + 10.0

      with ops.device("/job:worker/task:0"):
        result = math_ops.add_n(xs)

      self.evaluate(variables.global_variables_initializer())
      result_value = sess.run(result)
    self.assertEqual(result_value, expected)
    if result_value == expected:
      self._result_correct += 1

  def _wrapped_worker_fn(self, worker_fn):
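    """Wraps worker_fn so exceptions it raises are recorded by self._coord
    and re-raised when _join_threads() joins the coordinator.
    """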
    def wrapped(*args, **kwargs):
      with self._coord.stop_on_exception():
        return worker_fn(*args, **kwargs)
    return wrapped

  def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
    t = threading.Thread(
        target=distribute_coordinator.run_distribute_coordinator,
        args=(self._wrapped_worker_fn(worker_fn), strategy),
        kwargs=kwargs)
    t.start()
    return t

  def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
                                           cluster_spec, **kwargs):
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_coordinator_in_thread(
            worker_fn,
            strategy,
            cluster_spec=cluster_spec,
            task_type=task_type,
            task_id=task_id,
            **kwargs)
        threads[task_type].append(t)
    return threads

  def _join_threads(self, threads):
    try:
      self._coord.join(threads)
    except errors.UnknownError as e:
      if "Could not start gRPC server" in e.message:
        self.skipTest("Cannot start std servers.")
      else:
        raise

  def _between_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      with ops.device("/job:ps/task:0"):
        # TODO(yuefengz): investigate why not using resource variables makes
        # the test flaky.
        x = variable_scope.get_variable(
            "x", initializer=10.0, use_resource=True)
      with ops.device("/job:ps/task:1"):
        y = variable_scope.get_variable(
            "y", initializer=20.0, use_resource=True)

      x_add = x.assign_add(2.0)
      y_sub = y.assign_sub(2.0)
      train_op = control_flow_ops.group([x_add, y_sub])

      if context.is_chief:
        self.evaluate(variables.global_variables_initializer())

      # Synchronize workers after initialization.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
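        # Poll until the chief has initialized all variables.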
        while True:
          uninit_vars = sess.run(variables.report_uninitialized_variables())
          # pylint: disable=g-explicit-length-test
          if len(uninit_vars) == 0:
            break

      sess.run(train_op)

      # Synchronize workers after one step to make sure they all have finished
      # training.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        self._barrier.wait()

      x_val, y_val = sess.run([x, y])

      self.assertEqual(x_val, 16.0)
      self.assertEqual(y_val, 14.0)
      if x_val == 16.0 and y_val == 14.0:
        with self._lock:
          self._result_correct += 1

  def _between_graph_with_monitored_session(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with ops.device("/job:ps/task:0"):
      # TODO(yuefengz): investigate why not using resource variables makes
      # the test flaky.
      x = variable_scope.get_variable("xx", initializer=10.0, use_resource=True)
    with ops.device("/job:ps/task:1"):
      y = variable_scope.get_variable("yy", initializer=20.0, use_resource=True)

    x_add = x.assign_add(2.0)
    y_sub = y.assign_sub(2.0)
    train_op = control_flow_ops.group([x_add, y_sub])

    # The monitored session will run init or ready ops.
    with monitored_session.MonitoredSession() as sess:
      sess.run(train_op)

      # Synchronize workers after one step to make sure they all have finished
      # training.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        self._barrier.wait()

      x_val, y_val = sess.run([x, y])

    self.assertEqual(x_val, 16.0)
    self.assertEqual(y_val, 14.0)
    if x_val == 16.0 and y_val == 14.0:
      with self._lock:
        self._result_correct += 1

  def _dump_worker_context(self, strategy):
    """Dumps the properties of each worker context.

    It dumps the context properties to a dict mapping from task_type to a list
    of tuples of master_target, num_workers, is_chief and distributed_mode,
    where the list is indexed by the task_id.

    Args:
      strategy: a `DistributionStrategy` object.
    """
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    task_type = str(context.task_type)
    task_id = context.task_id or 0
    with self._lock:
      if task_type not in self._worker_context:
        self._worker_context[task_type] = []
      while len(self._worker_context[task_type]) <= task_id:
        self._worker_context[task_type].append(None)
      self._worker_context[task_type][task_id] = (context.master_target,
                                                  context.num_workers,
                                                  context.is_chief,
                                                  context.distributed_mode)

  def _dump_strategy_property(self, strategy):
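    """Dumps each task's should_init/should_checkpoint/should_save_summary."""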
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)

    self.assertEqual(context._strategy.extended.experimental_should_init,
                     strategy.extended.experimental_should_init)
    self.assertEqual(context.should_checkpoint,
                     strategy.extended.should_checkpoint)
    self.assertEqual(context.should_save_summary,
                     strategy.extended.should_save_summary)

    task_type = str(context.task_type)
    task_id = context.task_id or 0
    with self._lock:
      if task_type not in self._strategy_property:
        self._strategy_property[task_type] = []
      while len(self._strategy_property[task_type]) <= task_id:
        self._strategy_property[task_type].append(None)
      self._strategy_property[task_type][task_id] = (
          context._strategy.extended.experimental_should_init,
          context.should_checkpoint,
          context.should_save_summary)

  def _run_mock_std_server(self,
                           session_config=None,
                           cluster_spec=None,
                           task_type=None,
                           task_id=None,
                           rpc_layer=None,
                           environment=None):
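    """Records a MockServer for each task in place of a real std server."""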
    task_type = str(task_type)
    task_id = task_id or 0
    with self._lock:
      if task_type not in self._std_servers:
        self._std_servers[task_type] = []
      while len(self._std_servers[task_type]) <= task_id:
        self._std_servers[task_type].append(None)

      server = MockServer()
      self._std_servers[task_type][task_id] = server
    return server


class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):

  def testInGraphStandaloneMode(self):
    """Test it runs in-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._in_graph_worker_fn,
        MockStrategy(between_graph=False),
        cluster_spec=self._cluster_spec)
    self.assertEqual(self._result_correct, 1)

  def testBetweenGraph(self):
    """Test it runs between-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_worker_fn,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

  @test_util.run_v1_only("MonitoredSession removed from v2")
  def testBetweenGraphWithMonitoredSession(self):
    """Test monitored session in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_with_monitored_session,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

  def testBetweenGraphContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue(WORKER in self._worker_context)
    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(
        self._worker_context[WORKER][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
    self.assertEqual(
        self._worker_context[WORKER][1],
        (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
    self.assertEqual(
        self._worker_context[WORKER][2],
        (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))

  def testBetweenGraphStrategyProperties(self):
    # Dumps properties of the strategy objects.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_strategy_property,
        MockStrategy(between_graph=True, should_init=True),
        cluster_spec=self._cluster_spec)

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._strategy_property), 1)
    self.assertTrue(WORKER in self._strategy_property)
    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)

    # Check whether each task has the right properties of should_init,
    # should_checkpoint and should_save_summary.
    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))

  def testInGraphContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=False),
        cluster_spec=self._cluster_spec)

    # There is only a "None" task in the dumped task context.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue("None" in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(
        self._worker_context["None"][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))

  def testLocalContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=False),
        cluster_spec=None)

    # There is only a "None" task.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue("None" in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))

  def testBetweenGraphContextWithChief(self):
    # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
    cluster_spec = copy.deepcopy(self._cluster_spec)
    cluster_spec[CHIEF] = ["fake_chief"]

    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=True),
        cluster_spec=cluster_spec,
        rpc_layer="grpc")

    # There is one CHIEF task and three WORKER tasks.
    self.assertEqual(len(self._worker_context), 2)
    self.assertTrue(CHIEF in self._worker_context)
    self.assertTrue(WORKER in self._worker_context)
    self.assertEqual(len(self._worker_context[CHIEF]), 1)
    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(self._worker_context[CHIEF][0],
                     ("grpc://fake_chief", 4, True, True))
    self.assertEqual(
        self._worker_context[WORKER][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS + 1, False, True))
    self.assertEqual(
        self._worker_context[WORKER][1],
        (_bytes_to_str(self._workers[1].target), NUM_WORKERS + 1, False, True))
    self.assertEqual(
        self._worker_context[WORKER][2],
        (_bytes_to_str(self._workers[2].target), NUM_WORKERS + 1, False, True))

  def testInGraphContextWithEval(self):
    # Adds an EVALUATOR job.
    cluster_spec = copy.deepcopy(self._cluster_spec)
    cluster_spec[EVALUATOR] = ["fake_evaluator"]

    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=False),
        cluster_spec=cluster_spec,
        rpc_layer=None)

    # There is one "None" task and one EVALUATOR task.
    self.assertEqual(len(self._worker_context), 2)
    self.assertTrue("None" in self._worker_context)
    self.assertTrue(EVALUATOR in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)
    self.assertEqual(len(self._worker_context[EVALUATOR]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
        _bytes_to_str(self._workers[0].target)), 3, True, True))
    self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))


class DistributeCoordinatorTestIndependentWorkerMode(
    DistributeCoordinatorTestBase):

  def testInGraph(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    threads = self._run_multiple_coordinator_in_threads(
        self._in_graph_worker_fn,
        MockStrategy(between_graph=False),
        cluster_spec,
        mode=INDEPENDENT_WORKER)
    self._join_threads([threads[WORKER][0]])
    self.assertEqual(self._result_correct, 1)

  def testBetweenGraph(self):
    cluster_spec = self._create_cluster_spec(
        num_workers=NUM_WORKERS, num_ps=NUM_PS)
    threads = self._run_multiple_coordinator_in_threads(
        self._between_graph_worker_fn,
        MockStrategy(between_graph=True),
        cluster_spec,
        mode=INDEPENDENT_WORKER)
    self._join_threads(threads[WORKER])

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

  @test_util.run_v1_only("MonitoredSession removed from v2")
  def testBetweenGraphWithMonitoredSession(self):
    cluster_spec = self._create_cluster_spec(
        num_workers=NUM_WORKERS, num_ps=NUM_PS)
    threads = self._run_multiple_coordinator_in_threads(
        self._between_graph_with_monitored_session,
        MockStrategy(between_graph=True),
        cluster_spec,
        mode=INDEPENDENT_WORKER)
    self._join_threads(threads[WORKER])

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

  def testBetweenGraphContext(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    # Dumps the task contexts and std server arguments.
    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
          MockStrategy(between_graph=True),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
          rpc_layer=None)
      self._join_threads(threads[WORKER])

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue(WORKER in self._worker_context)
    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(
        self._worker_context[WORKER][0],
        (_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True))
    self.assertEqual(
        self._worker_context[WORKER][1],
        (_bytes_to_str(cluster_spec[WORKER][1]), NUM_WORKERS, False, True))
    self.assertEqual(
        self._worker_context[WORKER][2],
        (_bytes_to_str(cluster_spec[WORKER][2]), NUM_WORKERS, False, True))

    # Make sure each worker runs a std server.
    self.assertEqual(len(self._std_servers), 1)
    self.assertTrue(WORKER in self._std_servers)
    self.assertEqual(len(self._std_servers[WORKER]), 3)
    self.assertFalse(self._std_servers[WORKER][0].joined)
    self.assertFalse(self._std_servers[WORKER][1].joined)
    self.assertFalse(self._std_servers[WORKER][2].joined)

  def testBetweenGraphStrategyProperties(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    # Dumps properties of the strategy objects.
    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_strategy_property,
          MockStrategy(between_graph=True, should_init=True),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
          rpc_layer=None)
      self._join_threads(threads[WORKER])

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._strategy_property), 1)
    self.assertTrue(WORKER in self._strategy_property)
    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)

    # Check whether each task has the right properties of should_init,
    # should_checkpoint and should_save_summary.
    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))

  def testInGraphContext(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    # Dumps the task contexts and std server arguments.
    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
          MockStrategy(between_graph=False),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
          rpc_layer=None)
      self._join_threads(threads[WORKER])

    # There is only a "None" task in the dumped task context.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue("None" in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(
        self._worker_context["None"][0],
        (_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True))

    # Make sure each worker runs a std server.
    self.assertEqual(len(self._std_servers), 1)
    self.assertTrue(WORKER in self._std_servers)
    self.assertEqual(len(self._std_servers[WORKER]), 3)
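    # In in-graph replication only the chief (worker 0) runs worker_fn; the
    # other workers just start a std server and block in join().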
    self.assertFalse(self._std_servers[WORKER][0].joined)
    self.assertTrue(self._std_servers[WORKER][1].joined)
    self.assertTrue(self._std_servers[WORKER][2].joined)

  def testInGraphContextWithEval(self):
    # Adds an EVALUATOR job.
    cluster_spec = self._create_cluster_spec(
        num_workers=NUM_WORKERS, has_eval=True)

    # Dumps the task contexts and std server arguments.
    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
          MockStrategy(between_graph=False),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
          rpc_layer=None)
      self._join_threads(threads[WORKER])
      self._join_threads([threads[EVALUATOR][0]])

    # There is one "None" task and one EVALUATOR task.
    self.assertEqual(len(self._worker_context), 2)
    self.assertTrue("None" in self._worker_context)
    self.assertTrue(EVALUATOR in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)
    self.assertEqual(len(self._worker_context[EVALUATOR]), 1)

    # Check whether each task has the right master_target, num_workers,
    # is_chief and distributed_mode.
    self.assertEqual(self._worker_context["None"][0],
                     (_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
    self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))

    # Make sure each worker runs a std server.
    self.assertEqual(len(self._std_servers), 1)
    self.assertTrue(WORKER in self._std_servers)
    self.assertEqual(len(self._std_servers[WORKER]), 3)
    self.assertFalse(self._std_servers[WORKER][0].joined)
    self.assertTrue(self._std_servers[WORKER][1].joined)
    self.assertTrue(self._std_servers[WORKER][2].joined)

  def testRunStdServerInGoogleEnvironment(self):
    cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
    tf_config = {"cluster": cluster_spec, "environment": "google"}

    joined = [False]

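    # In the "google" environment the std server is expected to block by
    # calling time.sleep in a loop rather than joining a real server, so
    # patching time.sleep both detects that join() was reached and ends the
    # thread via the real sys.exit.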
    def _fake_sleep(_):
      joined[0] = True
      original_sys_exit(0)

    def _thread_fn(cluster_spec):
      distribute_coordinator.run_distribute_coordinator(
          None,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="ps",
          task_id=0)

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            time, "sleep", _fake_sleep):
      t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
      t.start()
      t.join()
    self.assertTrue(joined[0])

  def testRpcLayerEnvironmentVariable(self):
    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
    tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}

    rpc_layer_from_coordinator = [None]

    def _run_mock_server(cluster_spec=None,
                         task_type=None,
                         task_id=None,
                         session_config=None,
                         rpc_layer=None,
                         environment=None):
      del cluster_spec, task_type, task_id, session_config, environment
      rpc_layer_from_coordinator[0] = rpc_layer
      return MockServer()

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            distribute_coordinator, "_run_std_server", _run_mock_server):
      distribute_coordinator.run_distribute_coordinator(
          None,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="ps",
          task_id=0)
    self.assertEqual(rpc_layer_from_coordinator[0], "cake")


class StrategyConfigureTest(test.TestCase):

  def setUp(self):
    self._device_filters = []
    self._intra_op_parallelism_threads = None
    self._inter_op_parallelism_threads = None
    super(StrategyConfigureTest, self).setUp()

  def _dump_device_filters(self, *args, **kwargs):
    session_config = kwargs.get("session_config", None)
    self._device_filters.extend(session_config.device_filters)
    self._intra_op_parallelism_threads = (
        session_config.intra_op_parallelism_threads)
    self._inter_op_parallelism_threads = (
        session_config.inter_op_parallelism_threads)
    return MockServer()

  def _worker_fn(self, strategy):
    worker_context = distribute_coordinator_context.get_current_worker_context()
    session_config = worker_context._session_config
    self._device_filters.extend(session_config.device_filters)
    self._intra_op_parallelism_threads = (
        session_config.intra_op_parallelism_threads)
    self._inter_op_parallelism_threads = (
        session_config.inter_op_parallelism_threads)
    return MockServer()

  def test_session_config_in_std_server(self):
    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
    tf_config = {"cluster": cluster_spec}

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            distribute_coordinator, "_run_std_server",
            self._dump_device_filters):
      distribute_coordinator.run_distribute_coordinator(
          lambda _: None,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="worker",
          task_id=0)
    self.assertEqual(self._intra_op_parallelism_threads, 1)
    self.assertEqual(self._inter_op_parallelism_threads, 0)

  def test_session_config_in_session_creator(self):
    cluster_spec = {"worker": ["localhost:0"]}
    tf_config = {"cluster": cluster_spec}

    # Reset the saved Server state.
    distribute_coordinator._thread_local = threading.local()  # pylint: disable=protected-access

    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(tf_config)}):
      distribute_coordinator.run_distribute_coordinator(
          self._worker_fn,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="worker",
          task_id=0)
    self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
    self.assertEqual(self._intra_op_parallelism_threads, 2)
    self.assertEqual(self._inter_op_parallelism_threads, 0)

  def test_eval_strategy_configure(self):
    cluster_spec = {"evaluator": ["localhost:0"]}
    tf_config = {"cluster": cluster_spec}

    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(tf_config)}):
      distribute_coordinator.run_distribute_coordinator(
          lambda _: None,
          MockStrategy(between_graph=False),
          eval_fn=self._worker_fn,
          eval_strategy=MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="evaluator",
          task_id=0)
    self.assertEqual(self._device_filters, ["/job:somejob"])
    self.assertEqual(self._intra_op_parallelism_threads, 0)
    self.assertEqual(self._inter_op_parallelism_threads, 2)


class RunStandardTensorflowServerTest(test.TestCase):

  def test_std_server_arguments(self):
    cs = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
    tf_config = {"cluster": cs, "task": {"type": "ps", "id": 0}}

    def _mock_run_std_server(cluster_spec=None,
                             task_type=None,
                             task_id=None,
                             session_config=None,
                             rpc_layer=None):
      self.assertEqual(cluster_spec.as_dict(), cs)
      self.assertEqual(task_type, "ps")
      self.assertEqual(task_id, 0)
      self.assertEqual(session_config.experimental.collective_group_leader,
                       "/job:worker/replica:0/task:0")
      self.assertEqual(session_config.intra_op_parallelism_threads, 1)
      self.assertEqual(rpc_layer, "grpc")

      return MockServer()

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            distribute_coordinator, "_run_std_server", _mock_run_std_server):
      session_config = config_pb2.ConfigProto()
      session_config.intra_op_parallelism_threads = 1
      mock_server = distribute_coordinator.run_standard_tensorflow_server(
          session_config)
      self.assertTrue(mock_server.started)


if __name__ == "__main__":
  # TODO(yuefengz): find a smart way to terminate std server threads.
  with test.mock.patch.object(sys, "exit", os._exit):
    # Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly.
    orig_init = session_manager.SessionManager.__init__

    def new_init(*args, **kwargs):
      kwargs.pop("recovery_wait_secs", None)
      kwargs["recovery_wait_secs"] = 0.5
      orig_init(*args, **kwargs)

    session_manager.SessionManager.__init__ = new_init

    test.main()
943