xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/multi_process_runner_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `multi_process_runner`."""
16
17import ctypes
18import json
19import os
20import sys
21import threading
22import time
23import unittest
24
25from absl import logging
26from absl.testing import parameterized
27
28from tensorflow.python.distribute import combinations
29from tensorflow.python.distribute import multi_process_runner
30from tensorflow.python.distribute import multi_worker_test_base
31from tensorflow.python.eager import context
32from tensorflow.python.eager import test
33
34
def fn_that_adds_task_type_in_return_data():
  """Returns the task type reported by `multi_worker_test_base`.

  Used as a task function so that each task's return value identifies the
  job it ran as.
  """
  task_type = multi_worker_test_base.get_task_type()
  return task_type
37
38
def fn_that_errors():
  """Always raises `ValueError` with a fixed message.

  Raises:
    ValueError: unconditionally.
  """
  error = ValueError('This is an error.')
  raise error
41
42
def fn_that_does_nothing():
  """Intentionally performs no work and returns `None`."""
  return None
45
46
def fn_that_adds_simple_return_data():
  """Returns a fixed placeholder string."""
  placeholder = 'dummy_data'
  return placeholder
49
50
def fn_that_returns_args_and_kwargs(*args, **kwargs):
  """Echoes the call arguments back as a single list.

  Returns:
    A list holding the positional arguments in order, followed by one
    `(name, value)` tuple per keyword argument.
  """
  return [*args, *kwargs.items()]
53
54
def fn_with_barrier():
  """Returns whatever `multi_process_runner.get_barrier()` returns."""
  barrier = multi_process_runner.get_barrier()
  return barrier
57
58
def fn_that_returns_pid():
  """Returns the process id of the process executing this function."""
  current_pid = os.getpid()
  return current_pid
61
62
# Module-level slot that `fn_that_sets_global` overwrites; starts out unset.
V = None


def fn_that_sets_global(val):
  """Stores `val` in the module global `V`.

  Args:
    val: The new value for `V`.

  Returns:
    The value `V` held before this call.
  """
  global V
  # The right-hand side is evaluated first, so `previous` receives the old V.
  V, previous = val, V
  return previous
71
72
@combinations.generate(combinations.combine(required_gpus=0))
class MultiProcessRunnerTest(test.TestCase, parameterized.TestCase):
  """Tests for `MultiProcessRunner` and `multi_process_runner.run()`.

  The cases cover return values, captured output, error and exit-code
  reporting, terminating and restarting tasks, barriers, timeouts and
  `auto_restart`.
  """

  def _worker_idx(self):
    """Returns the task index read from the `TF_CONFIG` env variable."""
    tf_config = json.loads(os.environ['TF_CONFIG'])
    return tf_config['task']['index']

  def test_multi_process_runner(self):
    """Every task of a chief/worker/ps cluster returns its own task type."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=2, num_ps=3, has_chief=True)
    result = multi_process_runner.run(fn_that_adds_task_type_in_return_data,
                                      cluster_spec)

    # Tick off one expected task per returned task type; every counter has to
    # end at exactly zero.
    remaining = {'worker': 2, 'ps': 3, 'chief': 1}
    for task_type in result.return_value:
      remaining[task_type] -= 1

    for job in ('worker', 'ps', 'chief'):
      self.assertEqual(remaining[job], 0)

  def test_multi_process_runner_error_propagates_from_subprocesses(self):
    """An exception raised by the task function surfaces from `join()`."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1)
    runner = multi_process_runner.MultiProcessRunner(
        fn_that_errors, cluster_spec, max_run_time=20)
    runner.start()
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      runner.join()

  def test_multi_process_runner_queue_emptied_between_runs(self):
    """Return values of one run do not show up in the following run."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)

    first_run = multi_process_runner.run(fn_that_adds_simple_return_data,
                                         cluster_spec)
    returned = first_run.return_value
    self.assertTrue(returned)
    for worker_position in (0, 1):
      self.assertEqual(returned[worker_position], 'dummy_data')

    second_run = multi_process_runner.run(fn_that_does_nothing, cluster_spec)
    self.assertFalse(second_run.return_value)

  def test_multi_process_runner_args_passed_correctly(self):
    """`args` and `kwargs` given to `run()` reach the task function."""
    result = multi_process_runner.run(
        fn_that_returns_args_and_kwargs,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        args=('a', 'b'),
        kwargs={'c_k': 'c_v'})
    echoed = result.return_value[0]
    self.assertEqual(echoed[0], 'a')
    self.assertEqual(echoed[1], 'b')
    self.assertEqual(echoed[2], ('c_k', 'c_v'))

  def test_stdout_captured(self):
    """Printed lines are returned prefixed with the task that printed them."""

    def simple_print_func():
      print('This is something printed.', flush=True)
      return 'This is returned data.'

    result = multi_process_runner.run(
        simple_print_func,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    captured = result.stdout
    returned = result.return_value

    self.assertIn('[worker-0]:    This is something printed.\n', captured)
    self.assertIn('[worker-1]:    This is something printed.\n', captured)
    self.assertIn('This is returned data.', returned)

  def test_termination(self):
    """A terminated worker stops printing while the other one finishes."""

    def print_iterations():
      for step in range(10):
        message = 'index {}, iteration {}'.format(self._worker_idx(), step)
        print(message, flush=True)
        time.sleep(5)

    runner = multi_process_runner.MultiProcessRunner(
        print_iterations,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    runner.start()
    time.sleep(5)
    runner.terminate('worker', 0)

    output = runner.join().stdout

    # Worker 0 was stopped part-way through, so its last iteration must be
    # missing, whereas worker 1 ran to completion.
    self.assertIn('[worker-0]:    index 0, iteration 0\n', output)
    self.assertNotIn('[worker-0]:    index 0, iteration 9\n', output)
    self.assertIn('[worker-1]:    index 1, iteration 0\n', output)
    self.assertIn('[worker-1]:    index 1, iteration 9\n', output)

  def test_termination_and_start_single_process(self):
    """A terminated worker that is started again runs from the beginning."""

    def print_iterations():
      for step in range(10):
        message = 'index {}, iteration {}'.format(self._worker_idx(), step)
        print(message, flush=True)
        time.sleep(1)

    runner = multi_process_runner.MultiProcessRunner(
        print_iterations,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    runner.start()
    time.sleep(3)
    runner.terminate('worker', 0)
    runner.start_single_process('worker', 0)
    output = runner.join().stdout

    # The replacement worker 0 starts over, so iteration 0 is printed twice
    # and iteration 9 is still reached.
    first_iterations = [
        line for line in output if 'index 0, iteration 0' in line
    ]
    self.assertLen(first_iterations, 2)
    self.assertIn('[worker-0]:    index 0, iteration 9\n', output)
    self.assertIn('[worker-1]:    index 1, iteration 0\n', output)
    self.assertIn('[worker-1]:    index 1, iteration 9\n', output)

  def test_streaming(self):
    """Output of every task, including ones started later, is captured."""

    def log_and_print():
      for step in range(5):
        logging.info('(logging) %s-%d, i: %d',
                     multi_worker_test_base.get_task_type(), self._worker_idx(),
                     step)
        print(
            '(print) {}-{}, i: {}'.format(
                multi_worker_test_base.get_task_type(), self._worker_idx(),
                step),
            flush=True)
        time.sleep(1)

    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=2, num_ps=2)
    runner = multi_process_runner.MultiProcessRunner(
        log_and_print, cluster_spec, return_output=True)
    runner._dependence_on_chief = False

    runner.start()
    # Start a worker and a ps with task id 2 in addition to the cluster spec.
    runner.start_single_process('worker', 2)
    runner.start_single_process('ps', 2)
    output = runner.join().stdout

    def seen(fragment):
      return any(fragment in line for line in output)

    for step in range(5):
      self.assertTrue(seen('(logging) {}-0, i: {}'.format('chief', step)))
      self.assertTrue(seen('(print) {}-0, i: {}'.format('chief', step)))

    for job in ('worker', 'ps'):
      for step in range(5):
        # Task ids 0 through 2 were started for both jobs.
        for task in range(3):
          self.assertTrue(
              seen('(logging) {}-{}, i: {}'.format(job, task, step)))
          self.assertTrue(
              seen('(print) {}-{}, i: {}'.format(job, task, step)))
        # Task id 3 was never started, so nothing may be attributed to it.
        missing_task = 3
        self.assertFalse(
            seen('(logging) {}-{}, i: {}'.format(job, missing_task, step)))
        self.assertFalse(
            seen('(print) {}-{}, i: {}'.format(job, missing_task, step)))

  def test_start_in_process_as(self):
    """Runs the chief in-process while an evaluator is started by a thread."""

    def log_iterations():
      for step in range(5):
        logging.info('%s-%d, i: %d', multi_worker_test_base.get_task_type(),
                     self._worker_idx(), step)
        time.sleep(1)

    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=1)
    runner = multi_process_runner.MultiProcessRunner(
        log_iterations, cluster_spec, return_output=True)

    def eval_func():
      time.sleep(1)
      runner.start_single_process(task_type='evaluator', task_id=0)

    evaluator_starter = threading.Thread(target=eval_func)
    evaluator_starter.start()
    runner.start_in_process_as(as_task_type='chief', as_task_id=0)
    evaluator_starter.join()

    output = runner.join().stdout
    for job in ('worker', 'evaluator'):
      for step in range(5):
        expected = '{}-0, i: {}'.format(job, step)
        self.assertTrue(any(expected in line for line in output))

  def test_terminate_all_does_not_ignore_error(self):
    """Errors raised before `terminate_all()` are still reported by `join()`."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessRunner(
        fn_that_errors, cluster_spec, return_output=True)
    runner.start()
    time.sleep(60)
    runner.terminate_all()
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      runner.join()

  def test_barrier(self):
    """`get_barrier()` can be called from tasks started by `run()`."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=1)
    multi_process_runner.run(fn_with_barrier, cluster_spec=cluster_spec)

  def test_barrier_called_in_main_process(self):
    """`get_barrier()` raises `ValueError` when called in the test process."""
    with self.assertRaises(ValueError):
      multi_process_runner.get_barrier()

  def test_stdout_available_when_timeout(self):
    """Captured output is attached to the timeout error raised by `join()`."""

    def log_then_hang():
      logging.info('something printed')
      time.sleep(10000)  # Intentionally make the test timeout.

    timeout_error = multi_process_runner.SubprocessTimeoutError
    with self.assertRaises(timeout_error) as caught:
      runner = multi_process_runner.MultiProcessRunner(
          log_then_hang,
          multi_worker_test_base.create_cluster_spec(num_workers=1),
          return_output=True)
      runner.start()
      runner.join(timeout=60)
    runner.terminate_all()

    output = caught.exception.mpr_result.stdout
    self.assertTrue(any('something printed' in line for line in output))

  def test_seg_fault_raises_error(self):
    """A segfaulting worker yields `UnexpectedSubprocessExitError`."""

    if multi_process_runner.is_oss() or sys.version_info >= (3, 7):
      self.skipTest('TODO(b/171004637): Failing in OSS and Python 3.7+')

    def fn_expected_to_seg_fault():
      ctypes.string_at(0)  # Intentionally made seg fault.

    exit_error = multi_process_runner.UnexpectedSubprocessExitError
    with self.assertRaises(exit_error) as caught:
      multi_process_runner.run(
          fn_expected_to_seg_fault,
          multi_worker_test_base.create_cluster_spec(num_workers=1),
          return_output=True)

    self.assertIn('Subprocess worker-0 exited with exit code',
                  str(caught.exception))
    output = caught.exception.mpr_result.stdout
    self.assertTrue(any('Segmentation fault' in line for line in output))

  def test_seg_fault_in_chief_raises_error(self):
    """A segfaulting chief is reported even while a worker keeps sleeping."""

    if multi_process_runner.is_oss() or sys.version_info >= (3, 7):
      self.skipTest('TODO(b/171004637): Failing in OSS and Python 3.7+')

    def fn_expected_to_seg_fault():
      # Only the chief gets to the faulting call; workers sleep instead.
      if multi_worker_test_base.get_task_type() == 'worker':
        time.sleep(10000)
      ctypes.string_at(0)  # Intentionally made seg fault.

    exit_error = multi_process_runner.UnexpectedSubprocessExitError
    with self.assertRaises(exit_error) as caught:
      multi_process_runner.run(
          fn_expected_to_seg_fault,
          multi_worker_test_base.create_cluster_spec(
              has_chief=True, num_workers=1),
          return_output=True)

    self.assertIn('Subprocess chief-0 exited with exit code',
                  str(caught.exception))
    output = caught.exception.mpr_result.stdout
    self.assertTrue(any('Segmentation fault' in line for line in output))

  def test_exit_code_is_reported_by_chief_subprocess(self):
    """The chief's non-zero exit code appears in the error from `join()`."""

    def fn_expected_to_exit_with_20():
      # Only the chief reaches `sys.exit`; workers sleep instead.
      if multi_worker_test_base.get_task_type() == 'worker':
        time.sleep(10000)
      sys.exit(20)

    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=1)
    runner = multi_process_runner.MultiProcessRunner(
        fn_expected_to_exit_with_20, cluster_spec)
    runner.start()

    with self.assertRaisesRegex(
        multi_process_runner.UnexpectedSubprocessExitError,
        'Subprocess chief-0 exited with exit code 20'):
      runner.join()

  def test_exit_code_is_reported_by_subprocess(self):
    """A worker's non-zero exit code appears in the error from `join()`."""

    def fn_expected_to_exit_with_10():
      sys.exit(10)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=1)
    runner = multi_process_runner.MultiProcessRunner(
        fn_expected_to_exit_with_10, cluster_spec)
    runner.start()

    with self.assertRaisesRegex(
        multi_process_runner.UnexpectedSubprocessExitError,
        'Subprocess worker-0 exited with exit code 10'):
      runner.join()

  def test_auto_restart(self):
    """With `auto_restart`, a worker failing once ends up running twice."""

    def fail_on_first_attempt(counter):
      counter.value += 1
      if counter.value == 1:
        raise ValueError

    # The counter comes from the runner's manager and is passed to the task.
    attempts = multi_process_runner.manager().Value(int, 0)
    runner = multi_process_runner.MultiProcessRunner(
        fail_on_first_attempt,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        args=(attempts,),
        auto_restart=True)
    runner.start()
    runner.join()
    self.assertEqual(attempts.value, 2)

  def test_auto_restart_and_timeout(self):
    """A worker that always fails is restarted until `join()` times out."""

    def log_then_fail():
      logging.info('Running')
      time.sleep(1)
      raise ValueError

    runner = multi_process_runner.MultiProcessRunner(
        log_then_fail,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        auto_restart=True,
        return_output=True)
    runner.start()
    with self.assertRaises(ValueError) as caught:
      runner.join(timeout=10)

    # More than one 'Running' line means the task was started more than once.
    output = caught.exception.mpr_result.stdout
    running_count = sum(1 for msg in output if 'Running' in msg)
    self.assertGreater(running_count, 1)

  def test_auto_restart_and_chief(self):
    """The worker's error is raised when the chief returns normally."""
    # If the chief has exited with zero exit code, auto restart should stop
    # restarting other tasks even if they fail.

    def fail_unless_chief():
      time.sleep(1)
      if multi_worker_test_base.get_task_type() != 'chief':
        raise ValueError

    # The returned handle is not used below; the call is kept unchanged.
    manager = multi_process_runner.manager()
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=1)
    runner = multi_process_runner.MultiProcessRunner(
        fail_unless_chief, cluster_spec, auto_restart=True)
    runner.start()
    with self.assertRaises(ValueError):
      runner.join(timeout=10)

  def test_auto_restart_failure_immediate_after_restart(self):
    """Terminates worker-0 as soon as worker-1 has been restarted."""
    # Test the case when worker-0 fails immediately after worker-1 restarts.

    def short_sleep():
      time.sleep(5)

    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=False, num_workers=2)
    runner = multi_process_runner.MultiProcessRunner(
        short_sleep, cluster_spec, auto_restart=True)
    runner.start()

    original_pid = runner.get_process_id('worker', 1)
    runner.terminate('worker', 1)
    # Poll until worker-1 reports a process id different from the first one.
    while runner.get_process_id('worker', 1) == original_pid:
      time.sleep(0.1)
    runner.terminate('worker', 0)
    runner.join(timeout=20)

  def test_auto_restart_terminate(self):
    """A worker terminated by the test is started again by `auto_restart`."""
    # Tasks terminated by the user should also be restarted.

    def sleep_on_first_attempt(counter):
      counter.value += 1
      if counter.value == 1:
        time.sleep(100)

    attempts = multi_process_runner.manager().Value(int, 0)
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=False, num_workers=1)
    runner = multi_process_runner.MultiProcessRunner(
        sleep_on_first_attempt,
        cluster_spec,
        args=(attempts,),
        auto_restart=True)
    runner.start()
    time.sleep(3)
    runner.terminate('worker', 0)
    runner.join(timeout=20)
    self.assertEqual(attempts.value, 2)

  def test_error_reporting_overrides_timeout_reporting(self):
    """A worker's error is raised even though another worker never ends."""

    def fail_unless_second_worker():
      if self._worker_idx() == 1:
        time.sleep(10000)
      raise ValueError('Worker 0 errored')

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessRunner(
        fail_unless_second_worker, cluster_spec)
    runner.start()

    with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
      runner.join(timeout=20)

  def test_process_exists(self):
    """`process_exists()` is true while running and later turns false."""

    def sleep_for_a_long_time():
      time.sleep(100000)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=1)
    runner = multi_process_runner.MultiProcessRunner(sleep_for_a_long_time,
                                                     cluster_spec)
    runner.start()
    self.assertTrue(runner.process_exists('worker', 0))
    runner.terminate('worker', 0)
    # Worker 0 should exit at some point, or else the test would time out.
    while runner.process_exists('worker', 0):
      time.sleep(1)

  def test_timeout_none(self):
    """`join(timeout=None)` waits for a worker that fails after 250 seconds."""

    if multi_process_runner.is_oss():
      self.skipTest('Intentionally skipping longer test in OSS.')

    def sleep_then_fail():
      time.sleep(250)
      raise ValueError('Worker 0 errored')

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=1)
    runner = multi_process_runner.MultiProcessRunner(sleep_then_fail,
                                                     cluster_spec)

    runner.start()
    with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
      runner.join(timeout=None)
538
539
# Pool runner built at module level from a two-worker cluster spec. It is the
# runner used by `test_global_pool` and `test_nested_pool` below.
_global_pool = multi_process_runner.MultiProcessPoolRunner(
    multi_worker_test_base.create_cluster_spec(num_workers=2))
542
543
class MultiProcessPoolRunnerTest(test.TestCase):
  """Tests for `multi_process_runner.MultiProcessPoolRunner`."""

  def test_same_process_across_runs(self):
    """Repeated `run()` calls on one pool return the same process ids."""
    runner = multi_process_runner.MultiProcessPoolRunner(
        multi_worker_test_base.create_cluster_spec(num_workers=2))
    first_pids = runner.run(fn_that_returns_pid)
    for _ in range(3):
      later_pids = runner.run(fn_that_returns_pid)
      self.assertAllEqual(later_pids, first_pids)

  def test_exceptions_in_sub_process(self):
    """An error in a task is raised and the same processes remain in use."""
    runner = multi_process_runner.MultiProcessPoolRunner(
        multi_worker_test_base.create_cluster_spec(num_workers=2))
    pids_before = runner.run(fn_that_returns_pid)
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      runner.run(fn_that_errors)
    pids_after = runner.run(fn_that_returns_pid)
    self.assertAllEqual(pids_after, pids_before)

  def test_tf_config(self):
    """Each pool task reports the task type given by the cluster spec."""
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    task_types = runner.run(fn_that_adds_task_type_in_return_data)

    # Tick off one expected task per returned task type; every counter has to
    # end at exactly zero.
    remaining = {'worker': 2, 'chief': 1}
    for task_type in task_types:
      remaining[task_type] -= 1

    for job in ('worker', 'chief'):
      self.assertEqual(remaining[job], 0)

  @unittest.expectedFailure
  def test_exception_in_main_process(self):
    """Raises in the test process after the pool has been used once."""
    # When there's an exception in the main process, __del__() is not called.
    # This test is to verify MultiProcessPoolRunner can cope with __del__() not
    # being called.
    runner = multi_process_runner.MultiProcessPoolRunner(
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=2))
    runner.run(fn_that_returns_pid)
    raise ValueError('failure')

  def test_initializer(self):
    """The `initializer` has already run when the first task executes."""
    runner = multi_process_runner.MultiProcessPoolRunner(
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        initializer=lambda: fn_that_sets_global(1))
    # Each task returns the previous global, i.e. the value the initializer
    # stored.
    previous_values = runner.run(fn_that_sets_global, args=(2,))
    self.assertAllEqual(previous_values, [1, 1])

  def test_global_pool(self):
    """The module-level pool can run a task."""
    _global_pool.run(fn_that_does_nothing)

  def test_nested_pool(self):
    """A task run by the module-level pool can itself use `_global_pool`."""

    def run_in_nested_pool():
      # This runs in sub processes, so they are each using their own
      # MultiProcessPoolRunner.
      _global_pool.run(fn_that_does_nothing)

    _global_pool.run(run_in_nested_pool)
603
604
@combinations.generate(combinations.combine(required_physical_gpus=2))
class MultiProcessRunnerMultiGPUTest(test.TestCase, parameterized.TestCase):
  """Tests `MultiProcessRunner` with `share_gpu=False` on 2 or 4 GPUs."""

  def test_not_share_gpu(self):
    """Chief and worker see different, equally sized sets of GPUs."""
    gpu_count = len(context.context().list_physical_devices('GPU'))
    if gpu_count not in (2, 4):
      self.skipTest('requires 2 or 4 GPUs')
    has_two_gpus = gpu_count == 2
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=1)

    # Verify that CUDA_VISIBLE_DEVICES are different on each worker.

    def cuda_visible_devices_fn():
      return os.getenv('CUDA_VISIBLE_DEVICES')

    device_runner = multi_process_runner.MultiProcessRunner(
        cuda_visible_devices_fn, cluster_spec, share_gpu=False)
    device_runner.start()
    visible_devices = sorted(device_runner.join().return_value)
    expected_devices = ['0', '1'] if has_two_gpus else ['0,2', '1,3']
    self.assertAllEqual(visible_devices, expected_devices)

    # Verify that CUDA_VISIBLE_DEVICES works.

    def num_gpus_fn():
      return len(context.context().list_physical_devices('GPU'))

    count_runner = multi_process_runner.MultiProcessRunner(
        num_gpus_fn, cluster_spec, share_gpu=False)
    count_runner.start()
    gpus_per_task = count_runner.join().return_value
    expected_counts = [1, 1] if has_two_gpus else [2, 2]
    self.assertAllEqual(gpus_per_task, expected_counts)
642
643
if __name__ == '__main__':
  # When executed as a script, start the tests through
  # `multi_process_runner.test_main()`.
  multi_process_runner.test_main()
646