# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `multi_process_runner`."""

import ctypes
import json
import os
import sys
import threading
import time
import unittest

from absl import logging
from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import test


def fn_that_adds_task_type_in_return_data():
  return multi_worker_test_base.get_task_type()


def fn_that_errors():
  raise ValueError('This is an error.')


def fn_that_does_nothing():
  pass


def fn_that_adds_simple_return_data():
  return 'dummy_data'


def fn_that_returns_args_and_kwargs(*args, **kwargs):
  return list(args) + list(kwargs.items())


def fn_with_barrier():
  return multi_process_runner.get_barrier()


def fn_that_returns_pid():
  return os.getpid()


V = None


def fn_that_sets_global(val):
  global V
  old_val = V
  V = val
  return old_val


@combinations.generate(combinations.combine(required_gpus=0))
class MultiProcessRunnerTest(test.TestCase, parameterized.TestCase):

  def _worker_idx(self):
    config_task = json.loads(os.environ['TF_CONFIG'])['task']
    return config_task['index']

  def test_multi_process_runner(self):
    mpr_result = multi_process_runner.run(
        fn_that_adds_task_type_in_return_data,
        multi_worker_test_base.create_cluster_spec(
            num_workers=2, num_ps=3, has_chief=True))

    job_count_dict = {'worker': 2, 'ps': 3, 'chief': 1}
    for data in mpr_result.return_value:
      job_count_dict[data] -= 1

    self.assertEqual(job_count_dict['worker'], 0)
    self.assertEqual(job_count_dict['ps'], 0)
    self.assertEqual(job_count_dict['chief'], 0)

  def test_multi_process_runner_error_propagates_from_subprocesses(self):
    runner = multi_process_runner.MultiProcessRunner(
        fn_that_errors,
        multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
        max_run_time=20)
    runner.start()
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      runner.join()

  def test_multi_process_runner_queue_emptied_between_runs(self):
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    return_value = multi_process_runner.run(fn_that_adds_simple_return_data,
                                            cluster_spec).return_value
    self.assertTrue(return_value)
    self.assertEqual(return_value[0], 'dummy_data')
    self.assertEqual(return_value[1], 'dummy_data')
    return_value = multi_process_runner.run(fn_that_does_nothing,
                                            cluster_spec).return_value
    self.assertFalse(return_value)

  def test_multi_process_runner_args_passed_correctly(self):
    return_value = multi_process_runner.run(
        fn_that_returns_args_and_kwargs,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        args=('a', 'b'),
        kwargs={
            'c_k': 'c_v'
        }).return_value
    self.assertEqual(return_value[0][0], 'a')
    self.assertEqual(return_value[0][1], 'b')
    self.assertEqual(return_value[0][2], ('c_k', 'c_v'))

  def test_stdout_captured(self):

    def simple_print_func():
      print('This is something printed.', flush=True)
      return 'This is returned data.'

    mpr_result = multi_process_runner.run(
        simple_print_func,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    std_stream_results = mpr_result.stdout
    return_value = mpr_result.return_value
    self.assertIn('[worker-0]: This is something printed.\n',
                  std_stream_results)
    self.assertIn('[worker-1]: This is something printed.\n',
                  std_stream_results)
    self.assertIn('This is returned data.', return_value)

  def test_termination(self):

    def fn():
      for i in range(0, 10):
        print(
            'index {}, iteration {}'.format(self._worker_idx(), i), flush=True)
        time.sleep(5)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    mpr.start()
    time.sleep(5)
    mpr.terminate('worker', 0)

    std_stream_results = mpr.join().stdout

    # Worker 0 is terminated in the middle, so it should not have iteration 9
    # printed.
    self.assertIn('[worker-0]: index 0, iteration 0\n', std_stream_results)
    self.assertNotIn('[worker-0]: index 0, iteration 9\n',
                     std_stream_results)
    self.assertIn('[worker-1]: index 1, iteration 0\n', std_stream_results)
    self.assertIn('[worker-1]: index 1, iteration 9\n', std_stream_results)

  def test_termination_and_start_single_process(self):

    def fn():
      for i in range(0, 10):
        print(
            'index {}, iteration {}'.format(self._worker_idx(), i), flush=True)
        time.sleep(1)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    mpr.start()
    time.sleep(3)
    mpr.terminate('worker', 0)
    mpr.start_single_process('worker', 0)
    std_stream_results = mpr.join().stdout

    # Worker 0 is terminated in the middle, but a new worker 0 is added, so it
    # should still have iteration 9 printed. Moreover, iteration 0 of worker 0
    # should happen twice.
    self.assertLen(
        [s for s in std_stream_results if 'index 0, iteration 0' in s], 2)
    self.assertIn('[worker-0]: index 0, iteration 9\n', std_stream_results)
    self.assertIn('[worker-1]: index 1, iteration 0\n', std_stream_results)
    self.assertIn('[worker-1]: index 1, iteration 9\n', std_stream_results)

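  # Checks that both `logging` and `print` output from every task, including
  # the extra worker and ps started via `start_single_process`, shows up in
  # the captured stdout, and that no output appears for a task that was never
  # started.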
  def test_streaming(self):

    def fn():
      for i in range(5):
        logging.info('(logging) %s-%d, i: %d',
                     multi_worker_test_base.get_task_type(), self._worker_idx(),
                     i)
        print(
            '(print) {}-{}, i: {}'.format(
                multi_worker_test_base.get_task_type(), self._worker_idx(), i),
            flush=True)
        time.sleep(1)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=2, num_ps=2),
        return_output=True)
    # Disable the runner's dependence on the chief so that the remaining tasks
    # are not torn down as soon as the chief finishes its iterations.
    mpr._dependence_on_chief = False

    mpr.start()
    mpr.start_single_process('worker', 2)
    mpr.start_single_process('ps', 2)
    mpr_result = mpr.join()

    list_to_assert = mpr_result.stdout

    for job in ['chief']:
      for iteration in range(5):
        self.assertTrue(
            any('(logging) {}-0, i: {}'.format(job, iteration) in line
                for line in list_to_assert))
        self.assertTrue(
            any('(print) {}-0, i: {}'.format(job, iteration) in line
                for line in list_to_assert))

    for job in ['worker', 'ps']:
      for iteration in range(5):
        for task in range(3):
          self.assertTrue(
              any('(logging) {}-{}, i: {}'.format(job, task, iteration) in line
                  for line in list_to_assert))
          self.assertTrue(
              any('(print) {}-{}, i: {}'.format(job, task, iteration) in line
                  for line in list_to_assert))
        task = 3
        self.assertFalse(
            any('(logging) {}-{}, i: {}'.format(job, task, iteration) in line
                for line in list_to_assert))
        self.assertFalse(
            any('(print) {}-{}, i: {}'.format(job, task, iteration) in line
                for line in list_to_assert))

  def test_start_in_process_as(self):

    def fn():
      for i in range(5):
        logging.info('%s-%d, i: %d', multi_worker_test_base.get_task_type(),
                     self._worker_idx(), i)
        time.sleep(1)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1),
        return_output=True)

    def eval_func():
      time.sleep(1)
      mpr.start_single_process(task_type='evaluator', task_id=0)

    eval_thread = threading.Thread(target=eval_func)
    eval_thread.start()
    mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
    eval_thread.join()
    list_to_assert = mpr.join().stdout
    for job in ['worker', 'evaluator']:
      for iteration in range(5):
        self.assertTrue(
            any('{}-0, i: {}'.format(job, iteration) in line
                for line in list_to_assert))

  def test_terminate_all_does_not_ignore_error(self):
    mpr = multi_process_runner.MultiProcessRunner(
        fn_that_errors,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)
    mpr.start()
    time.sleep(60)
    mpr.terminate_all()
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      mpr.join()

  def test_barrier(self):
    multi_process_runner.run(
        fn_with_barrier,
        cluster_spec=multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1),
    )

  def test_barrier_called_in_main_process(self):
    with self.assertRaises(ValueError):
      multi_process_runner.get_barrier()

  def test_stdout_available_when_timeout(self):

    def fn():
      logging.info('something printed')
      time.sleep(10000)  # Intentionally make the test timeout.

    with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm:
      mpr = multi_process_runner.MultiProcessRunner(
          fn,
          multi_worker_test_base.create_cluster_spec(num_workers=1),
          return_output=True)
      mpr.start()
      mpr.join(timeout=60)
    mpr.terminate_all()

    list_to_assert = cm.exception.mpr_result.stdout
    self.assertTrue(
        any('something printed' in line for line in list_to_assert))

  def test_seg_fault_raises_error(self):

    if multi_process_runner.is_oss() or sys.version_info >= (3, 7):
      self.skipTest('TODO(b/171004637): Failing in OSS and Python 3.7+')

    def fn_expected_to_seg_fault():
      ctypes.string_at(0)  # Intentionally made seg fault.

    with self.assertRaises(
        multi_process_runner.UnexpectedSubprocessExitError) as cm:
      multi_process_runner.run(
          fn_expected_to_seg_fault,
          multi_worker_test_base.create_cluster_spec(num_workers=1),
          return_output=True)
    self.assertIn('Subprocess worker-0 exited with exit code',
                  str(cm.exception))
    list_to_assert = cm.exception.mpr_result.stdout
    self.assertTrue(
        any('Segmentation fault' in line for line in list_to_assert))

  def test_seg_fault_in_chief_raises_error(self):

    if multi_process_runner.is_oss() or sys.version_info >= (3, 7):
      self.skipTest('TODO(b/171004637): Failing in OSS and Python 3.7+')

    def fn_expected_to_seg_fault():
      if multi_worker_test_base.get_task_type() == 'worker':
        time.sleep(10000)
      ctypes.string_at(0)  # Intentionally made seg fault.

    with self.assertRaises(
        multi_process_runner.UnexpectedSubprocessExitError) as cm:
      multi_process_runner.run(
          fn_expected_to_seg_fault,
          multi_worker_test_base.create_cluster_spec(
              has_chief=True, num_workers=1),
          return_output=True)
    self.assertIn('Subprocess chief-0 exited with exit code',
                  str(cm.exception))
    list_to_assert = cm.exception.mpr_result.stdout
    self.assertTrue(
        any('Segmentation fault' in line for line in list_to_assert))

  def test_exit_code_is_reported_by_chief_subprocess(self):

    def fn_expected_to_exit_with_20():
      if multi_worker_test_base.get_task_type() == 'worker':
        time.sleep(10000)
      sys.exit(20)

    mpr = multi_process_runner.MultiProcessRunner(
        fn_expected_to_exit_with_20,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1))
    mpr.start()

    with self.assertRaisesRegex(
        multi_process_runner.UnexpectedSubprocessExitError,
        'Subprocess chief-0 exited with exit code 20'):
      mpr.join()

  def test_exit_code_is_reported_by_subprocess(self):

    def fn_expected_to_exit_with_10():
      sys.exit(10)

    mpr = multi_process_runner.MultiProcessRunner(
        fn_expected_to_exit_with_10,
        multi_worker_test_base.create_cluster_spec(num_workers=1))
    mpr.start()

    with self.assertRaisesRegex(
        multi_process_runner.UnexpectedSubprocessExitError,
        'Subprocess worker-0 exited with exit code 10'):
      mpr.join()

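  # The tests below exercise `auto_restart`: subprocesses that exit abnormally
  # are started again. In this first test, `fn` raises only on its first run,
  # so the restarted worker succeeds and the counter ends up at 2.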
  def test_auto_restart(self):

    def fn(counter):
      counter.value += 1
      if counter.value == 1:
        raise ValueError

    manager = multi_process_runner.manager()
    counter = manager.Value(int, 0)
    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        args=(counter,),
        auto_restart=True)
    mpr.start()
    mpr.join()
    self.assertEqual(counter.value, 2)

  def test_auto_restart_and_timeout(self):

    def fn():
      logging.info('Running')
      time.sleep(1)
      raise ValueError

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        auto_restart=True,
        return_output=True)
    mpr.start()
    with self.assertRaises(ValueError) as cm:
      mpr.join(timeout=10)
    self.assertGreater(
        sum(['Running' in msg for msg in cm.exception.mpr_result.stdout]), 1)

  def test_auto_restart_and_chief(self):
    # If the chief has exited with zero exit code, auto restart should stop
    # restarting other tasks even if they fail.

    def fn():
      time.sleep(1)
      if multi_worker_test_base.get_task_type() != 'chief':
        raise ValueError

    manager = multi_process_runner.manager()
    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1),
        auto_restart=True)
    mpr.start()
    with self.assertRaises(ValueError):
      mpr.join(timeout=10)

  def test_auto_restart_failure_immediate_after_restart(self):
    # Test the case when worker-0 fails immediately after worker-1 restarts.

    def fn():
      time.sleep(5)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=False, num_workers=2),
        auto_restart=True)
    mpr.start()
    pid = mpr.get_process_id('worker', 1)
    mpr.terminate('worker', 1)
    while mpr.get_process_id('worker', 1) == pid:
      time.sleep(0.1)
    mpr.terminate('worker', 0)
    mpr.join(timeout=20)

  def test_auto_restart_terminate(self):
    # Tasks terminated by the user should also be restarted.

    def fn(counter):
      counter.value += 1
      if counter.value == 1:
        time.sleep(100)

    manager = multi_process_runner.manager()
    counter = manager.Value(int, 0)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=False, num_workers=1),
        args=(counter,),
        auto_restart=True)
    mpr.start()
    time.sleep(3)
    mpr.terminate('worker', 0)
    mpr.join(timeout=20)
    self.assertEqual(counter.value, 2)

  def test_error_reporting_overrides_timeout_reporting(self):

    def fn():
      if self._worker_idx() == 1:
        time.sleep(10000)
      raise ValueError('Worker 0 errored')

    mpr = multi_process_runner.MultiProcessRunner(
        fn, multi_worker_test_base.create_cluster_spec(num_workers=2))
    mpr.start()

    with self.assertRaisesRegex(
        ValueError,
        'Worker 0 errored'):
      mpr.join(timeout=20)

  def test_process_exists(self):

    def fn():
      time.sleep(100000)

    mpr = multi_process_runner.MultiProcessRunner(
        fn, multi_worker_test_base.create_cluster_spec(num_workers=1))
    mpr.start()
    self.assertTrue(mpr.process_exists('worker', 0))
    mpr.terminate('worker', 0)
    # Worker 0 should exit at some point, or else the test would time out.
    while mpr.process_exists('worker', 0):
      time.sleep(1)

  def test_timeout_none(self):

    if multi_process_runner.is_oss():
      self.skipTest('Intentionally skipping longer test in OSS.')

    def fn():
      time.sleep(250)
      raise ValueError('Worker 0 errored')

    mpr = multi_process_runner.MultiProcessRunner(
        fn, multi_worker_test_base.create_cluster_spec(num_workers=1))

    mpr.start()
    with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
      mpr.join(timeout=None)


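# A module-level pool shared by the tests below. MultiProcessPoolRunner keeps
# its worker subprocesses alive between `run` calls (see
# `test_same_process_across_runs`), so `test_global_pool` and
# `test_nested_pool` reuse the same pool.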
_global_pool = multi_process_runner.MultiProcessPoolRunner(
    multi_worker_test_base.create_cluster_spec(num_workers=2))


class MultiProcessPoolRunnerTest(test.TestCase):

  def test_same_process_across_runs(self):
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    pid = runner.run(fn_that_returns_pid)
    for _ in range(3):
      self.assertAllEqual(runner.run(fn_that_returns_pid), pid)

  def test_exceptions_in_sub_process(self):
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    pid = runner.run(fn_that_returns_pid)
    with self.assertRaisesRegex(ValueError, 'This is an error.'):
      runner.run(fn_that_errors)
    self.assertAllEqual(runner.run(fn_that_returns_pid), pid)

  def test_tf_config(self):
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    result = runner.run(fn_that_adds_task_type_in_return_data)

    job_count_dict = {'worker': 2, 'chief': 1}
    for data in result:
      job_count_dict[data] -= 1

    self.assertEqual(job_count_dict['worker'], 0)
    self.assertEqual(job_count_dict['chief'], 0)

  @unittest.expectedFailure
  def test_exception_in_main_process(self):
    # When there's an exception in the main process, __del__() is not called.
    # This test is to verify MultiProcessPoolRunner can cope with __del__() not
    # being called.
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    runner.run(fn_that_returns_pid)
    raise ValueError('failure')

  def test_initializer(self):
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(
        cluster_spec, initializer=lambda: fn_that_sets_global(1))
    result = runner.run(fn_that_sets_global, args=(2,))
    self.assertAllEqual(result, [1, 1])

  def test_global_pool(self):
    _global_pool.run(fn_that_does_nothing)

  def test_nested_pool(self):

    def fn():
      # This runs in sub processes, so they are each using their own
      # MultiProcessPoolRunner.
      _global_pool.run(fn_that_does_nothing)

    _global_pool.run(fn)


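# These tests need physical GPUs. With `share_gpu=False`, each subprocess is
# assigned a disjoint subset of the available GPUs via CUDA_VISIBLE_DEVICES,
# which is what `test_not_share_gpu` asserts below.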
@combinations.generate(combinations.combine(required_physical_gpus=2))
class MultiProcessRunnerMultiGPUTest(test.TestCase, parameterized.TestCase):

  def test_not_share_gpu(self):
    num_gpus = len(context.context().list_physical_devices('GPU'))
    if num_gpus != 2 and num_gpus != 4:
      self.skipTest('requires 2 or 4 GPUs')
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=True, num_workers=1)

    # Verify that CUDA_VISIBLE_DEVICES are different on each worker.

    def cuda_visible_devices_fn():
      return os.getenv('CUDA_VISIBLE_DEVICES')

    runner = multi_process_runner.MultiProcessRunner(
        cuda_visible_devices_fn, cluster_spec, share_gpu=False)
    runner.start()
    result = runner.join()
    if num_gpus == 2:
      self.assertAllEqual(sorted(result.return_value), ['0', '1'])
    else:
      self.assertAllEqual(sorted(result.return_value), ['0,2', '1,3'])

    # Verify that CUDA_VISIBLE_DEVICES works.

    def num_gpus_fn():
      return len(context.context().list_physical_devices('GPU'))

    runner = multi_process_runner.MultiProcessRunner(
        num_gpus_fn, cluster_spec, share_gpu=False)
    runner.start()
    result = runner.join()
    if num_gpus == 2:
      self.assertAllEqual(result.return_value, [1, 1])
    else:
      self.assertAllEqual(result.return_value, [2, 2])


if __name__ == '__main__':
  multi_process_runner.test_main()