1# Owner(s): ["oncall: r2p"] 2 3# Copyright (c) Facebook, Inc. and its affiliates. 4# All rights reserved. 5# 6# This source code is licensed under the BSD-style license found in the 7# LICENSE file in the root directory of this source tree. 8 9import copy 10import os 11import pickle 12import socket 13import threading 14import time 15from abc import ABC, abstractmethod 16from base64 import b64encode 17from datetime import datetime, timedelta 18from typing import Callable, cast, Optional, Tuple 19from unittest import TestCase 20from unittest.mock import call, MagicMock, Mock, patch, PropertyMock 21 22import torch.distributed as dist 23from torch.distributed import HashStore, Store 24from torch.distributed.elastic.rendezvous import ( 25 RendezvousClosedError, 26 RendezvousError, 27 RendezvousInfo, 28 RendezvousParameters, 29 RendezvousStateError, 30 RendezvousStoreInfo, 31 RendezvousTimeoutError, 32) 33from torch.distributed.elastic.rendezvous.dynamic_rendezvous import ( 34 _Action, 35 _BackendRendezvousStateHolder, 36 _DistributedRendezvousOpExecutor, 37 _NodeDesc, 38 _NodeDescGenerator, 39 _RendezvousCloseOp, 40 _RendezvousContext, 41 _RendezvousExitOp, 42 _RendezvousJoinOp, 43 _RendezvousKeepAliveOp, 44 _RendezvousState, 45 _RendezvousStateHolder, 46 create_handler, 47 DynamicRendezvousHandler, 48 RendezvousBackend, 49 RendezvousSettings, 50 RendezvousTimeout, 51 Token, 52) 53 54 55class CustomAssertMixin: 56 assertDictEqual: Callable 57 58 def assert_state_equal( 59 self, actual: _RendezvousState, expected: _RendezvousState 60 ) -> None: 61 self.assertDictEqual(vars(actual), vars(expected)) 62 63 def assert_state_empty(self, actual: _RendezvousState) -> None: 64 self.assertDictEqual(vars(actual), vars(_RendezvousState())) 65 66 67class RendezvousTimeoutTest(TestCase): 68 def test_init_initializes_timeout(self) -> None: 69 timeout = RendezvousTimeout( 70 timedelta(seconds=50), 71 timedelta(seconds=60), 72 timedelta(seconds=70), 73 timedelta(seconds=80), 74 ) 75 76 
self.assertEqual(timeout.join, timedelta(seconds=50)) 77 self.assertEqual(timeout.last_call, timedelta(seconds=60)) 78 self.assertEqual(timeout.close, timedelta(seconds=70)) 79 self.assertEqual(timeout.heartbeat, timedelta(seconds=80)) 80 81 def test_init_initializes_timeout_if_no_timeout_is_specified(self) -> None: 82 timeout = RendezvousTimeout() 83 84 self.assertEqual(timeout.join, timedelta(seconds=600)) 85 self.assertEqual(timeout.last_call, timedelta(seconds=30)) 86 self.assertEqual(timeout.close, timedelta(seconds=30)) 87 self.assertEqual(timeout.heartbeat, timedelta(seconds=5)) 88 89 def test_init_raises_error_if_timeout_is_not_positive(self) -> None: 90 join_timeouts = [timedelta(seconds=0), timedelta(seconds=-1)] 91 92 for join_timeout in join_timeouts: 93 with self.subTest(join_timeout=join_timeout): 94 with self.assertRaisesRegex( 95 ValueError, 96 rf"^The join timeout \({join_timeout}\) must be positive.$", 97 ): 98 timeout = RendezvousTimeout(join_timeout) 99 100 101class NodeDescTest(TestCase): 102 def test_repr(self) -> None: 103 desc = _NodeDesc("dummy_fqdn", 3, 5) 104 105 self.assertEqual(repr(desc), "dummy_fqdn_3_5") 106 107 def test_hash(self) -> None: 108 desc1 = _NodeDesc("dummy_fqdn", 2, 4) 109 desc2 = _NodeDesc("dummy_fqdn", 3, 5) 110 111 descs = {desc1, desc2} 112 113 self.assertIn(desc1, descs) 114 self.assertIn(desc2, descs) 115 116 117class NodeDescGeneratorTest(TestCase): 118 def test_generate(self) -> None: 119 desc_generator = _NodeDescGenerator() 120 121 fqdn = socket.getfqdn() 122 123 pid = os.getpid() 124 125 for local_id in range(4): 126 with self.subTest(fqdn=fqdn, pid=pid, local_id=local_id): 127 desc = desc_generator.generate() 128 129 self.assertEqual(repr(desc), f"{fqdn}_{pid}_{local_id}") 130 131 132class RendezvousStateTest(TestCase): 133 def test_encoded_size_is_within_expected_limit(self) -> None: 134 state = _RendezvousState() 135 state.round = 1 136 state.complete = True 137 state.deadline = datetime.utcnow() 138 
state.closed = True 139 140 # fmt: off 141 expected_max_sizes = ( 142 ( 5, 2 * (2 ** 10),), # 10 machines <= 2KB # noqa: E201, E241, E262 143 ( 50, 16 * (2 ** 10),), # 100 machines <= 16KB # noqa: E201, E241, E262 144 ( 500, 160 * (2 ** 10),), # 1000 machines <= 160KB # noqa: E201, E241, E262 145 (5000, 1600 * (2 ** 10),), # 10000 machines <= 1.6MB # noqa: E201, E241, E262 146 ) 147 # fmt: on 148 149 for num_nodes, max_byte_size in expected_max_sizes: 150 with self.subTest(num_nodes=num_nodes, max_byte_size=max_byte_size): 151 for i in range(num_nodes): 152 node_running = _NodeDesc( 153 f"dummy{i}.dummy1-dummy1-dummy1-dummy1.com", 12345, i 154 ) 155 node_waiting = _NodeDesc( 156 f"dummy{i}.dummy2-dummy2-dummy2-dummy2.com", 67890, i 157 ) 158 159 state.participants[node_running] = i 160 161 state.wait_list.add(node_waiting) 162 163 state.last_heartbeats[node_running] = datetime.utcnow() 164 state.last_heartbeats[node_waiting] = datetime.utcnow() 165 166 bits = pickle.dumps(state) 167 168 base64_bits = b64encode(bits) 169 170 self.assertLessEqual(len(base64_bits), max_byte_size) 171 172 173class FakeRendezvousBackend(RendezvousBackend): 174 _state: Optional[bytes] 175 _token: int 176 177 def __init__(self) -> None: 178 self._state = None 179 self._token = 0 180 181 @property 182 def name(self) -> str: 183 return "fake_backend" 184 185 def get_state(self) -> Optional[Tuple[bytes, Token]]: 186 if self._token == 0: 187 return None 188 189 return self._state, self._token # type: ignore[return-value] 190 191 def set_state( 192 self, state: bytes, token: Optional[Token] = None 193 ) -> Optional[Tuple[bytes, Token, bool]]: 194 if token is None: 195 token = 0 196 197 if token == self._token: 198 self._state = state 199 self._token += 1 200 201 has_set = True 202 else: 203 has_set = False 204 205 return self._state, self._token, has_set # type: ignore[return-value] 206 207 def get_state_internal(self) -> _RendezvousState: 208 return pickle.loads(cast(bytes, self._state)) 209 
210 def set_state_internal(self, state: _RendezvousState) -> None: 211 self._state = pickle.dumps(state) 212 self._token += 1 213 214 def corrupt_state(self) -> None: 215 self._state = b"corrupt_state" 216 self._token += 1 217 218 219class BackendRendezvousStateHolderTest(TestCase, CustomAssertMixin): 220 def setUp(self) -> None: 221 self._backend = FakeRendezvousBackend() 222 223 mock_get_state = MagicMock(wraps=self._backend.get_state) 224 mock_set_state = MagicMock(wraps=self._backend.set_state) 225 226 self._mock_backend = Mock() 227 self._mock_backend.get_state = mock_get_state 228 self._mock_backend.set_state = mock_set_state 229 230 setattr(self._backend, "get_state", mock_get_state) # noqa: B010 231 setattr(self._backend, "set_state", mock_set_state) # noqa: B010 232 233 self._settings = RendezvousSettings( 234 run_id="dummy_run_id", 235 min_nodes=1, 236 max_nodes=1, 237 timeout=RendezvousTimeout(), 238 keep_alive_interval=timedelta(seconds=30), 239 keep_alive_max_attempt=3, 240 ) 241 242 self._cache_duration = 0 243 244 self._now = datetime(2000, 1, 1, hour=0, minute=0) 245 246 self._datetime_patch = patch( 247 "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime" 248 ) 249 250 mock_datetime = self._datetime_patch.start() 251 mock_datetime.utcnow.return_value = self._now 252 253 def tearDown(self) -> None: 254 self._datetime_patch.stop() 255 256 def _create_state(self) -> _RendezvousState: 257 state = _RendezvousState() 258 state.round = 999 259 state.complete = True 260 state.deadline = self._now 261 state.closed = True 262 state.participants = { 263 _NodeDesc("dummy1", 1, 1): 0, 264 _NodeDesc("dummy2", 1, 1): 1, 265 _NodeDesc("dummy3", 1, 1): 2, 266 } 267 state.wait_list = { 268 _NodeDesc("dummy4", 1, 1), 269 _NodeDesc("dummy5", 1, 1), 270 } 271 state.last_heartbeats = { 272 _NodeDesc("dummy1", 1, 1): self._now, 273 _NodeDesc("dummy2", 1, 1): self._now - timedelta(seconds=15), 274 _NodeDesc("dummy3", 1, 1): self._now - timedelta(seconds=30), 
275 _NodeDesc("dummy4", 1, 1): self._now - timedelta(seconds=60), 276 _NodeDesc("dummy5", 1, 1): self._now - timedelta(seconds=90), 277 } 278 279 return state 280 281 def _create_state_holder(self) -> _BackendRendezvousStateHolder: 282 return _BackendRendezvousStateHolder( 283 self._backend, self._settings, self._cache_duration 284 ) 285 286 def test_init_initializes_state_holder(self) -> None: 287 state_holder = self._create_state_holder() 288 289 self.assert_state_empty(state_holder.state) 290 291 self._mock_backend.assert_not_called() 292 293 def test_sync_gets_empty_state_if_backend_state_does_not_exist(self) -> None: 294 state_holder = self._create_state_holder() 295 296 has_set = state_holder.sync() 297 298 self.assertIsNone(has_set) 299 300 self.assert_state_empty(state_holder.state) 301 302 self.assertEqual(self._mock_backend.get_state.call_count, 1) 303 self.assertEqual(self._mock_backend.set_state.call_count, 0) 304 305 def test_sync_gets_backend_state_if_local_state_is_clean(self) -> None: 306 state_holder = self._create_state_holder() 307 308 expected_state = self._create_state() 309 310 for attempt in range(1, 4): 311 with self.subTest(attempt=attempt): 312 expected_state.round = attempt 313 314 self._backend.set_state_internal(expected_state) 315 316 has_set = state_holder.sync() 317 318 self.assertIsNone(has_set) 319 320 self.assert_state_equal(state_holder.state, expected_state) 321 322 self.assertEqual(self._mock_backend.get_state.call_count, 1) 323 self.assertEqual(self._mock_backend.set_state.call_count, 0) 324 325 self._mock_backend.reset_mock() 326 327 def test_sync_gets_backend_state_if_local_state_is_old_and_dirty(self) -> None: 328 state_holder = self._create_state_holder() 329 330 expected_state = self._create_state() 331 332 for attempt in range(1, 4): 333 with self.subTest(attempt=attempt): 334 self._backend.set_state_internal(expected_state) # Increment token. 
335 336 state_holder.state.round = attempt 337 state_holder.mark_dirty() 338 339 has_set = state_holder.sync() 340 341 self.assertFalse(has_set) 342 343 self.assert_state_equal(state_holder.state, expected_state) 344 345 self.assertEqual(self._mock_backend.get_state.call_count, 0) 346 self.assertEqual(self._mock_backend.set_state.call_count, 1) 347 348 self._mock_backend.reset_mock() 349 350 def test_sync_sets_backend_state_if_local_state_is_new_and_dirty(self) -> None: 351 state_holder = self._create_state_holder() 352 353 for attempt in range(1, 4): 354 with self.subTest(attempt=attempt): 355 state_holder.state.round = attempt 356 state_holder.mark_dirty() 357 358 has_set = state_holder.sync() 359 360 self.assertTrue(has_set) 361 362 expected_state = self._backend.get_state_internal() 363 364 self.assert_state_equal(state_holder.state, expected_state) 365 366 self.assertEqual(self._mock_backend.get_state.call_count, 0) 367 self.assertEqual(self._mock_backend.set_state.call_count, 1) 368 369 self._mock_backend.reset_mock() 370 371 def test_sync_uses_cached_state_if_cache_duration_is_specified(self) -> None: 372 state = self._create_state() 373 374 self._backend.set_state_internal(state) 375 376 with patch( 377 "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time" 378 ) as mock_time: 379 for cache_duration in [1, 5, 10]: 380 with self.subTest(cache_duration=cache_duration): 381 self._cache_duration = cache_duration 382 383 state_holder = self._create_state_holder() 384 385 mock_time.monotonic.return_value = 5 386 387 state_holder.sync() 388 389 has_set = state_holder.sync() 390 391 self.assertIsNone(has_set) 392 393 self.assertEqual(self._mock_backend.get_state.call_count, 1) 394 self.assertEqual(self._mock_backend.set_state.call_count, 0) 395 396 mock_time.monotonic.return_value = 5 + self._cache_duration 397 398 state_holder.sync() 399 400 has_set = state_holder.sync() 401 402 self.assertIsNone(has_set) 403 404 
self.assertEqual(self._mock_backend.get_state.call_count, 1) 405 self.assertEqual(self._mock_backend.set_state.call_count, 0) 406 407 self._mock_backend.get_state.reset_mock() 408 409 def test_sync_gets_backend_state_if_cached_state_has_expired(self) -> None: 410 state = self._create_state() 411 412 self._backend.set_state_internal(state) 413 414 with patch( 415 "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time" 416 ) as mock_time: 417 self._cache_duration = 1 418 419 state_holder = self._create_state_holder() 420 421 mock_time.monotonic.return_value = 5 422 423 state_holder.sync() 424 425 has_set = state_holder.sync() 426 427 self.assertIsNone(has_set) 428 429 self.assertEqual(self._mock_backend.get_state.call_count, 1) 430 self.assertEqual(self._mock_backend.set_state.call_count, 0) 431 432 mock_time.monotonic.return_value = 5 + self._cache_duration + 0.01 433 434 state_holder.sync() 435 436 has_set = state_holder.sync() 437 438 self.assertIsNone(has_set) 439 440 self.assertEqual(self._mock_backend.get_state.call_count, 2) 441 self.assertEqual(self._mock_backend.set_state.call_count, 0) 442 443 def test_sync_sanitizes_state(self) -> None: 444 state = self._create_state() 445 446 expected_state = copy.deepcopy(state) 447 448 dead_node1 = _NodeDesc("dead1", 1, 1) 449 dead_node2 = _NodeDesc("dead2", 1, 1) 450 dead_node3 = _NodeDesc("dead3", 1, 1) 451 dead_node4 = _NodeDesc("dead4", 1, 1) 452 dead_node5 = _NodeDesc("dead5", 1, 1) 453 454 state.last_heartbeats[dead_node1] = self._now - timedelta(seconds=91) 455 state.last_heartbeats[dead_node2] = self._now - timedelta(seconds=100) 456 state.last_heartbeats[dead_node3] = self._now - timedelta(seconds=110) 457 state.last_heartbeats[dead_node4] = self._now - timedelta(seconds=120) 458 state.last_heartbeats[dead_node5] = self._now - timedelta(seconds=130) 459 460 state.participants[dead_node1] = 0 461 state.participants[dead_node2] = 0 462 state.participants[dead_node3] = 0 463 464 
state.wait_list.add(dead_node4) 465 state.wait_list.add(dead_node5) 466 467 self._backend.set_state_internal(state) 468 469 state_holder = self._create_state_holder() 470 471 state_holder.sync() 472 473 self.assert_state_equal(state_holder.state, expected_state) 474 475 def test_sync_sanitizes_state_if_no_participants_is_left(self) -> None: 476 state = self._create_state() 477 478 expected_state = copy.deepcopy(state) 479 480 for node in state.last_heartbeats: 481 state.last_heartbeats[node] = self._now - timedelta(seconds=100) 482 483 expected_state.complete = False 484 expected_state.round = 1000 485 expected_state.participants = {} 486 expected_state.wait_list = set() 487 expected_state.last_heartbeats = {} 488 489 self._backend.set_state_internal(state) 490 491 state_holder = self._create_state_holder() 492 493 state_holder.sync() 494 495 self.assert_state_equal(state_holder.state, expected_state) 496 497 def test_sync_raises_error_if_backend_state_is_corrupt(self) -> None: 498 self._backend.corrupt_state() 499 500 state_holder = self._create_state_holder() 501 502 with self.assertRaisesRegex( 503 RendezvousStateError, 504 r"^The rendezvous state is corrupt. 
See inner exception for details.$", 505 ): 506 state_holder.sync() 507 508 509class FakeRendezvousStateHolder(_RendezvousStateHolder): 510 _state: _RendezvousState 511 _dirty: Optional[bool] 512 513 def __init__(self) -> None: 514 self._state = _RendezvousState() 515 self._dirty = None 516 517 @property 518 def state(self) -> _RendezvousState: 519 return self._state 520 521 @state.setter 522 def state(self, value) -> None: 523 self._state = value 524 525 def sync(self) -> Optional[bool]: 526 self._dirty, dirty = None, self._dirty 527 528 return dirty 529 530 def mark_dirty(self) -> None: 531 self._dirty = True 532 533 534class DistributedRendezvousOpExecutorTest(TestCase, CustomAssertMixin): 535 def setUp(self) -> None: 536 self._node = _NodeDesc("this_node", 1, 1) 537 538 self._state_holder = FakeRendezvousStateHolder() 539 540 mock_sync = MagicMock(wraps=self._state_holder.sync) 541 mock_mark = MagicMock(wraps=self._state_holder.mark_dirty) 542 543 self._mock_state_holder = Mock() 544 self._mock_state_holder.sync = mock_sync 545 self._mock_state_holder.mark = mock_mark 546 547 setattr(self._state_holder, "sync", mock_sync) # noqa: B010 548 setattr(self._state_holder, "mark_dirty", mock_mark) # noqa: B010 549 550 self._state = self._state_holder.state 551 552 self._min_nodes = 1 553 self._max_nodes = 1 554 555 self._timeout = RendezvousTimeout() 556 557 self._now = datetime(2000, 1, 1, hour=0, minute=0) 558 559 self._datetime_patch = patch( 560 "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime" 561 ) 562 563 mock_datetime = self._datetime_patch.start() 564 mock_datetime.utcnow.return_value = self._now 565 566 def tearDown(self) -> None: 567 self._datetime_patch.stop() 568 569 def _create_settings(self) -> RendezvousSettings: 570 return RendezvousSettings( 571 run_id="dummy_run_id", 572 min_nodes=self._min_nodes, 573 max_nodes=self._max_nodes, 574 timeout=self._timeout, 575 keep_alive_interval=timedelta(seconds=30), 576 keep_alive_max_attempt=3, 577 
) 578 579 def _create_op_executor( 580 self, settings: Optional[RendezvousSettings] = None 581 ) -> _DistributedRendezvousOpExecutor: 582 self._state_holder.state = self._state 583 584 if settings is None: 585 settings = self._create_settings() 586 587 return _DistributedRendezvousOpExecutor( 588 self._node, self._state_holder, settings 589 ) 590 591 def _run_action(self, action: _Action) -> None: 592 op_executor = self._create_op_executor() 593 594 op = MagicMock(side_effect=[action, _Action.FINISH]) 595 596 op_executor.run(op, deadline=1) 597 598 def _assert_action(self, action: _Action, expected_state: _RendezvousState) -> None: 599 self._run_action(action) 600 601 self.assert_state_equal(self._state, expected_state) 602 603 self.assertListEqual( 604 self._mock_state_holder.mock_calls, [call.sync(), call.mark(), call.sync()] 605 ) 606 607 def test_run_passes_expected_context_and_deadline_to_state_handler(self) -> None: 608 settings = self._create_settings() 609 610 op_executor = self._create_op_executor(settings) 611 612 op = MagicMock(return_value=_Action.FINISH) 613 614 op_executor.run(op, deadline=3) 615 616 ctx, deadline = op.call_args[0] # args 617 618 self.assertIs(ctx.node, self._node) 619 self.assertIs(ctx.state, self._state) 620 self.assertIs(ctx.settings, settings) 621 622 self.assertEqual(deadline, 3) 623 624 def test_run_keeps_alive(self) -> None: 625 expected_state = _RendezvousState() 626 627 expected_state.last_heartbeats[self._node] = self._now 628 629 self._assert_action(_Action.KEEP_ALIVE, expected_state) 630 631 def test_run_adds_to_participants(self) -> None: 632 expected_state = _RendezvousState() 633 634 expected_state.participants[self._node] = 0 635 636 expected_state.last_heartbeats[self._node] = self._now 637 638 self._min_nodes = 2 639 self._max_nodes = 2 640 641 self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state) 642 643 def test_run_adds_to_participants_if_node_was_in_waitlist(self) -> None: 644 
self._state.wait_list.add(self._node) 645 646 expected_state = _RendezvousState() 647 648 expected_state.participants[self._node] = 0 649 650 expected_state.last_heartbeats[self._node] = self._now 651 652 self._min_nodes = 2 653 self._max_nodes = 2 654 655 self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state) 656 657 def _add_participants( 658 self, num_participants: int, state: _RendezvousState, ranked: bool = False 659 ) -> None: 660 for i in range(num_participants): 661 if ranked: 662 node = _NodeDesc(f"dummy{i}", 1, 1) 663 rank = i 664 else: 665 node = _NodeDesc( 666 f"dummy{num_participants - i - 1}", 1, 1 667 ) # Add in reverse. 668 rank = 0 669 670 state.participants[node] = rank 671 672 state.last_heartbeats[node] = self._now 673 674 def test_run_adds_to_participants_and_starts_last_call_if_min_nodes_is_reached( 675 self, 676 ) -> None: 677 for num_participants in range(3): 678 self._state = _RendezvousState() 679 680 self._add_participants(num_participants, self._state) 681 682 self._state.wait_list.add(self._node) 683 684 expected_state = _RendezvousState() 685 686 self._add_participants(num_participants, expected_state) 687 688 expected_state.participants[self._node] = 0 689 690 expected_state.last_heartbeats[self._node] = self._now 691 692 expected_state.deadline = self._now + self._timeout.last_call 693 694 with self.subTest(num_participants=num_participants): 695 self._min_nodes = num_participants + 1 696 self._max_nodes = num_participants + 2 697 698 self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state) 699 700 self._mock_state_holder.reset_mock() 701 702 def test_run_adds_to_participants_and_completes_rendezvous_if_max_nodes_is_reached( 703 self, 704 ) -> None: 705 for min_max_nodes_equal in [False, True]: 706 for num_participants in range(3): 707 rank = num_participants 708 709 self._state = _RendezvousState() 710 711 self._add_participants(num_participants, self._state) 712 713 self._state.wait_list.add(self._node) 714 715 
self._state.deadline = self._now + self._timeout.last_call 716 717 expected_state = _RendezvousState() 718 719 self._add_participants(num_participants, expected_state, ranked=True) 720 721 expected_state.participants[self._node] = rank 722 723 expected_state.last_heartbeats[self._node] = self._now 724 725 expected_state.complete = True 726 expected_state.deadline = None 727 728 with self.subTest(num_participants=num_participants): 729 self._min_nodes = num_participants + 1 if min_max_nodes_equal else 0 730 self._max_nodes = num_participants + 1 731 732 self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state) 733 734 self._mock_state_holder.reset_mock() 735 736 def test_run_adds_to_waitlist(self) -> None: 737 expected_state = _RendezvousState() 738 739 expected_state.wait_list.add(self._node) 740 741 expected_state.last_heartbeats[self._node] = self._now 742 743 self._assert_action(_Action.ADD_TO_WAIT_LIST, expected_state) 744 745 def test_run_removes_from_participants(self) -> None: 746 for complete, last_call_deadline in [(False, self._now), (True, None)]: 747 self._state = _RendezvousState() 748 749 self._add_participants(2, self._state) 750 751 self._state.participants[self._node] = 0 752 753 self._state.last_heartbeats[self._node] = self._now 754 755 self._state.complete = complete 756 self._state.deadline = last_call_deadline 757 758 self._state.round = 1 759 760 expected_state = _RendezvousState() 761 762 self._add_participants(2, expected_state) 763 764 expected_state.complete = complete 765 expected_state.deadline = last_call_deadline 766 767 expected_state.round = 1 768 769 with self.subTest(complete=complete): 770 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state) 771 772 self._mock_state_holder.reset_mock() 773 774 def test_run_removes_from_participants_and_moves_to_next_round_if_node_is_last_participant( 775 self, 776 ) -> None: 777 self._state.participants[self._node] = 0 778 779 self._state.last_heartbeats[self._node] = 
self._now 780 781 self._state.complete = True 782 783 self._state.round = 1 784 785 expected_state = _RendezvousState() 786 787 expected_state.complete = False 788 789 expected_state.round = 2 790 791 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state) 792 793 def test_run_removes_from_participants_and_clears_last_call_if_rendezvous_has_less_than_min_nodes( 794 self, 795 ) -> None: 796 self._add_participants(2, self._state) 797 798 self._state.participants[self._node] = 0 799 800 self._state.last_heartbeats[self._node] = self._now 801 802 self._state.deadline = self._now 803 804 expected_state = _RendezvousState() 805 806 self._add_participants(2, expected_state) 807 808 self._min_nodes = 3 809 self._max_nodes = 4 810 811 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state) 812 813 def test_run_removes_from_waitlist(self) -> None: 814 self._state.wait_list.add(self._node) 815 816 self._state.last_heartbeats[self._node] = self._now 817 818 expected_state = _RendezvousState() 819 820 self._assert_action(_Action.REMOVE_FROM_WAIT_LIST, expected_state) 821 822 def test_run_marks_rendezvous_closed(self) -> None: 823 expected_state = _RendezvousState() 824 825 expected_state.closed = True 826 827 self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED, expected_state) 828 829 def test_run_raises_error_if_rendezvous_is_closed(self) -> None: 830 with self.assertRaises(RendezvousClosedError): 831 self._run_action(_Action.ERROR_CLOSED) 832 833 self.assertListEqual(self._mock_state_holder.mock_calls, [call.sync()]) 834 835 def test_run_raises_error_if_operation_timed_out(self) -> None: 836 with self.assertRaises(RendezvousTimeoutError): 837 self._run_action(_Action.ERROR_TIMEOUT) 838 839 self.assertListEqual(self._mock_state_holder.mock_calls, [call.sync()]) 840 841 def test_run_delays_execution_if_sync_requested(self) -> None: 842 with patch( 843 "torch.distributed.elastic.rendezvous.dynamic_rendezvous._delay" 844 ) as mock_delay: 845 
self._run_action(_Action.SYNC) 846 847 mock_delay.assert_called_once_with(seconds=1) 848 849 self.assertListEqual( 850 self._mock_state_holder.mock_calls, [call.sync(), call.sync()] 851 ) 852 853 854class AbstractTestRendezvousOp(ABC): 855 assertEqual: Callable 856 857 def setUp(self) -> None: 858 self._node = _NodeDesc("this_node", 1, 1) 859 860 self._min_nodes = 1 861 self._max_nodes = 2 862 863 self._keep_alive_interval = timedelta(seconds=30) 864 865 self._state = _RendezvousState() 866 self._state.participants[_NodeDesc("dummy1", 1, 1)] = 1 867 868 self._now = datetime(2000, 1, 1, hour=0, minute=0) 869 870 self._deadline = 10 871 872 self._datetime_patch = patch( 873 "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime" 874 ) 875 876 mock_datetime = self._datetime_patch.start() 877 mock_datetime.utcnow.return_value = self._now 878 879 self._time_patch = patch( 880 "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time" 881 ) 882 883 mock_time = self._time_patch.start() 884 mock_time.monotonic.return_value = self._deadline 885 886 def tearDown(self) -> None: 887 self._time_patch.stop() 888 self._datetime_patch.stop() 889 890 def _get_next_action(self) -> _Action: 891 op = self._create_op() 892 893 settings = RendezvousSettings( 894 run_id="dummy_run_id", 895 min_nodes=self._min_nodes, 896 max_nodes=self._max_nodes, 897 timeout=RendezvousTimeout(), 898 keep_alive_interval=self._keep_alive_interval, 899 keep_alive_max_attempt=3, 900 ) 901 902 ctx = _RendezvousContext(self._node, self._state, settings) 903 904 return op(ctx, self._deadline) 905 906 @abstractmethod 907 def _create_op(self) -> Callable: 908 pass 909 910 def _assert_action(self, expected_action) -> None: 911 action = self._get_next_action() 912 913 self.assertEqual(action, expected_action) 914 915 916class TestRendezvousExitOp(AbstractTestRendezvousOp, TestCase): 917 def _create_op(self) -> Callable: 918 return _RendezvousExitOp() 919 920 def 
test_removes_from_participants_if_node_is_participant(self) -> None: 921 self._state.participants[self._node] = 1 922 923 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS) 924 925 def test_raises_timeout_if_deadline_exceeded(self) -> None: 926 self._deadline = 0 927 928 self._state.participants[self._node] = 1 929 930 self._assert_action(_Action.ERROR_TIMEOUT) 931 932 def test_finishes_if_node_is_not_participant(self) -> None: 933 self._assert_action(_Action.FINISH) 934 935 936class TestRendezvousJoinOp(AbstractTestRendezvousOp, TestCase): 937 def _create_op(self) -> Callable: 938 return _RendezvousJoinOp() 939 940 def test_raises_closed_if_rendezvous_is_closed(self) -> None: 941 self._state.closed = True 942 943 self._assert_action(_Action.ERROR_CLOSED) 944 945 def test_finishes_if_rendezvous_is_complete_and_node_is_participant(self) -> None: 946 self._state.participants[self._node] = 0 947 948 self._state.complete = True 949 950 self._assert_action(_Action.FINISH) 951 952 def _assert_waits_rendezvous_completion(self) -> None: 953 keep_alive_time = self._now - self._keep_alive_interval 954 955 for delta, expected_action in [ 956 (timedelta(seconds=0), _Action.KEEP_ALIVE), 957 (timedelta(seconds=1), _Action.SYNC), 958 ]: 959 self._state.last_heartbeats[self._node] = keep_alive_time + delta 960 961 self._assert_action(expected_action) 962 963 def test_treat_as_redundancy_for_next_rendezvous_if_rendezvous_is_complete( 964 self, 965 ) -> None: 966 self._max_nodes = 1 967 968 self._state.complete = True 969 970 self._assert_action(_Action.ADD_TO_REDUNDANCY_LIST) 971 972 def test_waits_next_round_if_rendezvous_is_complete_and_node_is_redundant( 973 self, 974 ) -> None: 975 self._state.redundancy_list.add(self._node) 976 977 self._max_nodes = 1 978 979 self._state.complete = True 980 981 self._assert_waits_rendezvous_completion() 982 983 def test_remove_from_rednundancy_list(self) -> None: 984 self._state.redundancy_list.add(self._node) 985 986 self._max_nodes = 2 
987 988 self._state.complete = True 989 990 self._assert_action(_Action.REMOVE_FROM_REDUNDANCY_LIST) 991 992 def test_waits_next_round_if_rendezvous_is_complete_and_node_is_in_wait_list( 993 self, 994 ) -> None: 995 self._state.wait_list.add(self._node) 996 997 self._state.complete = True 998 999 self._assert_waits_rendezvous_completion() 1000 1001 def test_adds_to_wait_list_if_rendezvous_is_complete_and_num_nodes_is_less_than_max_nodes( 1002 self, 1003 ) -> None: 1004 self._state.complete = True 1005 1006 self._assert_action(_Action.ADD_TO_WAIT_LIST) 1007 1008 def test_waits_rendezvous_to_complete_if_node_is_participant(self) -> None: 1009 self._max_nodes = 3 1010 1011 self._state.participants[self._node] = 0 1012 1013 self._state.deadline = self._now 1014 1015 self._assert_waits_rendezvous_completion() 1016 1017 def test_marks_rendezvous_complete_if_node_is_participant_and_last_call_deadline_exceeded( 1018 self, 1019 ) -> None: 1020 self._max_nodes = 3 1021 1022 self._state.participants[self._node] = 0 1023 1024 self._state.deadline = self._now - timedelta(seconds=1) 1025 1026 self._assert_action(_Action.MARK_RENDEZVOUS_COMPLETE) 1027 1028 def test_adds_to_participants(self) -> None: 1029 self._assert_action(_Action.ADD_TO_PARTICIPANTS) 1030 1031 def test_raises_timeout_if_deadline_exceeded(self) -> None: 1032 self._deadline = 0 1033 1034 self._assert_action(_Action.ERROR_TIMEOUT) 1035 1036 def test_raises_timeout_if_rollback_deadline_exceeded_and_node_is_participant( 1037 self, 1038 ) -> None: 1039 self._deadline = 0 1040 1041 self._state.participants[self._node] = 0 1042 1043 self._assert_action(_Action.ERROR_TIMEOUT) 1044 1045 def test_raises_timeout_if_rollback_deadline_exceeded_and_node_is_in_wait_list( 1046 self, 1047 ) -> None: 1048 self._deadline = 0 1049 1050 self._state.wait_list.add(self._node) 1051 1052 self._assert_action(_Action.ERROR_TIMEOUT) 1053 1054 def test_removes_from_participants_if_timed_out_but_rollback_deadline_is_not_reached( 1055 self, 
1056 ) -> None: 1057 self._deadline = 5 1058 1059 self._state.participants[self._node] = 0 1060 1061 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS) 1062 1063 def test_removes_from_wait_list_if_timed_out_but_rollback_deadline_is_not_reached( 1064 self, 1065 ) -> None: 1066 self._deadline = 5 1067 1068 self._state.wait_list.add(self._node) 1069 1070 self._assert_action(_Action.REMOVE_FROM_WAIT_LIST) 1071 1072 def test_no_timeout_for_redundant_node(self) -> None: 1073 self._max_nodes = 1 1074 self._deadline = 0 1075 self._state.complete = True 1076 1077 self._state.redundancy_list.add(self._node) 1078 1079 self._assert_action(_Action.SYNC) 1080 1081 def test_keep_alive_for_redundant_node(self) -> None: 1082 self._deadline = 0 1083 self._max_nodes = 1 1084 self._state.complete = True 1085 1086 self._state.redundancy_list.add(self._node) 1087 1088 keep_alive_time = self._now - self._keep_alive_interval 1089 self._state.last_heartbeats[self._node] = keep_alive_time 1090 self._assert_action(_Action.KEEP_ALIVE) 1091 1092 1093class TestRendezvousCloseOp(AbstractTestRendezvousOp, TestCase): 1094 def _create_op(self) -> Callable: 1095 return _RendezvousCloseOp() 1096 1097 def test_finishes_if_rendezvous_is_closed(self) -> None: 1098 self._state.closed = True 1099 1100 self._assert_action(_Action.FINISH) 1101 1102 def test_raises_timeout_if_deadline_exceeded(self) -> None: 1103 self._deadline = 0 1104 1105 self._assert_action(_Action.ERROR_TIMEOUT) 1106 1107 def test_marks_rendezvous_closed(self) -> None: 1108 self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED) 1109 1110 1111class TestRendezvousKeepAliveOp(AbstractTestRendezvousOp, TestCase): 1112 def _create_op(self) -> Callable: 1113 return _RendezvousKeepAliveOp() 1114 1115 def test_updates_keep_alive_if_needed(self) -> None: 1116 keep_alive_time = self._now - self._keep_alive_interval 1117 1118 for delta in [timedelta(seconds=0), timedelta(seconds=-1)]: 1119 with self.subTest(delta=delta): 1120 
self._state.last_heartbeats[self._node] = keep_alive_time + delta 1121 1122 self._assert_action(_Action.KEEP_ALIVE) 1123 1124 def test_raises_timeout_if_deadlined_exceeded(self) -> None: 1125 self._deadline = 0 1126 1127 self._state.last_heartbeats[self._node] = self._now - self._keep_alive_interval 1128 1129 self._assert_action(_Action.ERROR_TIMEOUT) 1130 1131 def test_finishes_if_no_keep_alive_update_is_needed(self) -> None: 1132 delta = timedelta(seconds=1) 1133 1134 self._state.last_heartbeats[self._node] = ( 1135 self._now - self._keep_alive_interval + delta 1136 ) 1137 1138 self._assert_action(_Action.FINISH) 1139 1140 1141class DummyStore(Store): 1142 pass 1143 1144 1145class DynamicRendezvousHandlerTest(TestCase): 1146 def setUp(self) -> None: 1147 self._node = _NodeDesc("this_node", 1, 1) 1148 1149 self._min_nodes = 1 1150 self._max_nodes = 1 1151 1152 self._join_timeout: Optional[timedelta] = None 1153 self._close_timeout: Optional[timedelta] = None 1154 self._heartbeat_timeout: Optional[timedelta] = None 1155 1156 self._keep_alive_interval = timedelta(seconds=30) 1157 1158 self._store = DummyStore() 1159 1160 self._mock_store_get = MagicMock(return_value=b"123") 1161 self._mock_store_set = MagicMock() 1162 1163 setattr(self._store, "get", self._mock_store_get) # noqa: B010 1164 setattr(self._store, "set", self._mock_store_set) # noqa: B010 1165 1166 self._state_holder = FakeRendezvousStateHolder() 1167 1168 self._mock_sync = MagicMock(wraps=self._state_holder.sync) 1169 1170 setattr(self._state_holder, "sync", self._mock_sync) # noqa: B010 1171 1172 self._state = self._state_holder.state 1173 1174 self._tcp_store_mock = DummyStore() 1175 1176 patcher = patch.object( 1177 DynamicRendezvousHandler, 1178 "_create_tcp_store_server", 1179 return_value=self._tcp_store_mock, 1180 ) 1181 patcher.start() 1182 self.addCleanup(patcher.stop) 1183 1184 def _create_handler(self) -> DynamicRendezvousHandler: 1185 settings = RendezvousSettings( 1186 
run_id="dummy_run_id", 1187 min_nodes=self._min_nodes, 1188 max_nodes=self._max_nodes, 1189 timeout=RendezvousTimeout( 1190 join=self._join_timeout, 1191 close=self._close_timeout, 1192 heartbeat=self._heartbeat_timeout, 1193 ), 1194 keep_alive_interval=self._keep_alive_interval, 1195 keep_alive_max_attempt=3, 1196 ) 1197 1198 self._state_holder.state = self._state 1199 1200 return DynamicRendezvousHandler( 1201 self._node, settings, "dummy_backend", self._store, self._state_holder 1202 ) 1203 1204 def test_share_store_creates_tcp_store(self): 1205 handler = self._create_handler() 1206 1207 shared_store_info = RendezvousStoreInfo("host", 54321) 1208 with patch.object(RendezvousStoreInfo, "build", return_value=shared_store_info): 1209 rdzv_info = handler.next_rendezvous() 1210 self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host") 1211 self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321) 1212 self.assertEqual(handler._shared_tcp_store_server, self._tcp_store_mock) 1213 1214 rdzv_info = handler.next_rendezvous() 1215 self.assertEqual(handler._shared_tcp_store_server, self._tcp_store_mock) 1216 1217 def test_share_store_when_tcp_store(self): 1218 handler = self._create_handler() 1219 1220 with patch.object(dist, "PrefixStore", new=Mock): 1221 handler._store = Mock(spec=dist.TCPStore) 1222 type(handler._store).host = PropertyMock(return_value="host") 1223 type(handler._store).port = PropertyMock(return_value=54321) 1224 rdzv_info = handler.next_rendezvous() 1225 self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host") 1226 self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321) 1227 self.assertEqual(handler._shared_tcp_store_server, handler._store) 1228 1229 rdzv_info = handler.next_rendezvous() 1230 self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host") 1231 self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321) 1232 self.assertEqual(handler._shared_tcp_store_server, handler._store) 1233 
1234 @patch("torch.distributed.elastic.rendezvous.dynamic_rendezvous._delay") 1235 def test_next_rendezvous_skews_the_first_join_attempt(self, mock_delay) -> None: 1236 for round, expected_call_count in [(0, True), (1, False)]: 1237 with self.subTest(round=round): 1238 self._state.round = round 1239 1240 handler = self._create_handler() 1241 1242 handler.next_rendezvous() 1243 1244 self.assertEqual(mock_delay.call_count, expected_call_count) 1245 1246 mock_delay.reset_mock() 1247 1248 def test_next_rendezvous_returns_expected_value(self) -> None: 1249 self._state.participants[_NodeDesc("dummy1", 1, 1)] = 0 1250 self._state.participants[_NodeDesc("dummy2", 1, 1)] = 0 1251 1252 self._max_nodes = 3 1253 1254 handler = self._create_handler() 1255 1256 rdzv_info = handler.next_rendezvous() 1257 1258 self.assertEqual(rdzv_info.rank, 2) 1259 self.assertEqual(rdzv_info.world_size, 3) 1260 1261 _ = rdzv_info.store.get("dummy_key") 1262 1263 self._mock_store_get.assert_called_with( 1264 "torch.rendezvous.dummy_run_id.0/dummy_key" 1265 ) 1266 1267 def test_next_rendezvous_respects_the_requested_timeout(self) -> None: 1268 self._mock_sync.side_effect = lambda: time.sleep(0.3) 1269 1270 self._join_timeout = timedelta(seconds=0.2) 1271 1272 handler = self._create_handler() 1273 1274 with self.assertRaises(RendezvousTimeoutError): 1275 handler.next_rendezvous() 1276 1277 def test_next_rendezvous_moves_to_next_round_if_called_repeatedly(self) -> None: 1278 handler = self._create_handler() 1279 1280 for i in range(4): 1281 handler.next_rendezvous() 1282 1283 self.assertEqual(self._state.round, i) 1284 1285 def test_is_closed_returns_expected_value(self) -> None: 1286 for closed in [False, True]: 1287 with self.subTest(closed=closed): 1288 self._state.closed = closed 1289 1290 handler = self._create_handler() 1291 1292 self.assertEqual(handler.is_closed(), closed) 1293 1294 self._mock_sync.assert_called_once() 1295 1296 self._mock_sync.reset_mock() 1297 1298 
@patch("torch.distributed.elastic.events.record_rdzv_event") 1299 def test_is_closed_records_and_raises_exceptions(self, record_mock) -> None: 1300 self._mock_sync.side_effect = RendezvousError("test error") 1301 handler = self._create_handler() 1302 with self.assertRaises(RendezvousError): 1303 handler.is_closed() 1304 record_mock.assert_called_once() 1305 1306 def test_set_closed_closes_rendezvous(self) -> None: 1307 handler = self._create_handler() 1308 1309 handler.set_closed() 1310 1311 self.assertTrue(self._state.closed) 1312 1313 def test_set_closed_respects_the_requested_timeout(self) -> None: 1314 self._mock_sync.side_effect = lambda: time.sleep(0.3) 1315 1316 self._close_timeout = timedelta(seconds=0.2) 1317 1318 handler = self._create_handler() 1319 1320 with self.assertRaises(RendezvousTimeoutError): 1321 handler.set_closed() 1322 1323 def test_set_closed_can_be_called_multiple_times(self) -> None: 1324 handler = self._create_handler() 1325 1326 handler.set_closed() 1327 handler.set_closed() 1328 1329 self.assertTrue(self._state.closed) 1330 1331 @patch("torch.distributed.elastic.events.record_rdzv_event") 1332 def test_set_closed_records_and_raises_exceptions(self, record_mock) -> None: 1333 with patch.object(DynamicRendezvousHandler, "_close") as close_mock: 1334 close_mock.side_effect = RendezvousError("test error") 1335 handler = self._create_handler() 1336 with self.assertRaises(RendezvousError): 1337 handler.set_closed() 1338 record_mock.assert_called_once() 1339 1340 def test_num_nodes_waiting_returns_expected_value(self) -> None: 1341 self._state.wait_list.add(_NodeDesc("dummy1", 1, 1)) 1342 self._state.wait_list.add(_NodeDesc("dummy2", 1, 1)) 1343 1344 handler = self._create_handler() 1345 1346 self.assertEqual(handler.num_nodes_waiting(), 2) 1347 1348 self._mock_sync.assert_called_once() 1349 1350 @patch("torch.distributed.elastic.events.record_rdzv_event") 1351 def test_num_nodes_waiting_records_and_raises_exceptions(self, record_mock) -> 
None: 1352 self._mock_sync.side_effect = RendezvousError("test error") 1353 handler = self._create_handler() 1354 with self.assertRaises(RendezvousError): 1355 handler.num_nodes_waiting() 1356 record_mock.assert_called_once() 1357 1358 def test_shutdown_closes_rendezvous_and_returns_true(self) -> None: 1359 handler = self._create_handler() 1360 1361 result = handler.shutdown() 1362 1363 self.assertTrue(result) 1364 1365 self.assertTrue(self._state.closed) 1366 1367 def test_shutdown_returns_false_if_rendezvous_cannot_be_closed(self) -> None: 1368 self._mock_sync.side_effect = [RendezvousError] 1369 1370 handler = self._create_handler() 1371 1372 result = handler.shutdown() 1373 1374 self.assertFalse(result) 1375 1376 def test_shutdown_can_be_called_multiple_times(self) -> None: 1377 handler = self._create_handler() 1378 1379 handler.shutdown() 1380 handler.shutdown() 1381 1382 self.assertTrue(self._state.closed) 1383 1384 @patch("torch.distributed.elastic.events.record_rdzv_event") 1385 def test_shutdown_records_and_raises_exceptions(self, record_mock) -> None: 1386 with patch.object(DynamicRendezvousHandler, "_close") as close_mock: 1387 close_mock.side_effect = RuntimeError("test error") 1388 handler = self._create_handler() 1389 with self.assertRaises(RuntimeError): 1390 handler.shutdown() 1391 record_mock.assert_called_once() 1392 1393 @patch("torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime") 1394 def test_keep_alive_updates_last_heartbeat(self, mock_datetime) -> None: 1395 now = datetime(2000, 1, 1, hour=0, minute=0) 1396 1397 mock_datetime.utcnow.return_value = now 1398 1399 self._state.last_heartbeats[self._node] = now - (self._keep_alive_interval * 2) 1400 1401 handler = self._create_handler() 1402 1403 handler._keep_alive() 1404 1405 self.assertEqual(self._state.last_heartbeats[self._node], now) 1406 1407 def _assert_keep_alive_swallows_rendezvous_errors(self) -> None: 1408 last_heartbeat_time = datetime.utcnow() - 
(self._keep_alive_interval * 2) 1409 1410 self._state.last_heartbeats[self._node] = last_heartbeat_time 1411 1412 handler = self._create_handler() 1413 1414 handler._keep_alive() 1415 1416 self.assertEqual(self._state.last_heartbeats[self._node], last_heartbeat_time) 1417 1418 def test_keep_alive_swallows_rendezvous_errors(self) -> None: 1419 self._mock_sync.side_effect = [RendezvousError] 1420 1421 self._assert_keep_alive_swallows_rendezvous_errors() 1422 1423 def test_keep_alive_respects_the_requested_timeout(self) -> None: 1424 self._mock_sync.side_effect = lambda: time.sleep(0.3) 1425 1426 self._heartbeat_timeout = timedelta(seconds=0.2) 1427 1428 self._assert_keep_alive_swallows_rendezvous_errors() 1429 1430 def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_shutdown( 1431 self, 1432 ) -> None: 1433 self._node = _NodeDesc("this_node", 1, 2) 1434 1435 name = "RendezvousKeepAliveTimer_2" 1436 1437 handler = self._create_handler() 1438 1439 self.assertTrue(all(t.name != name for t in threading.enumerate())) 1440 1441 handler.next_rendezvous() 1442 1443 self.assertTrue(any(t.name == name for t in threading.enumerate())) 1444 1445 handler.shutdown() 1446 1447 self.assertTrue(all(t.name != name for t in threading.enumerate())) 1448 1449 def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_finalizer( 1450 self, 1451 ) -> None: 1452 self._node = _NodeDesc("this_node", 1, 3) 1453 1454 name = "RendezvousKeepAliveTimer_3" 1455 1456 handler = self._create_handler() 1457 1458 self.assertTrue(all(t.name != name for t in threading.enumerate())) 1459 1460 handler.next_rendezvous() 1461 1462 self.assertTrue(any(t.name == name for t in threading.enumerate())) 1463 1464 del handler 1465 1466 self.assertTrue(all(t.name != name for t in threading.enumerate())) 1467 1468 1469class DummyRendezvousBackend(RendezvousBackend): 1470 @property 1471 def name(self): 1472 return "dummy_backend" 1473 1474 def get_state(self): 1475 return None 
1476 1477 def set_state(self, state, token): 1478 return None 1479 1480 1481class DynamicRendezvousHandlerFromBackendTest(TestCase): 1482 def setUp(self) -> None: 1483 self._run_id = "dummy_run_id" 1484 self._store = DummyStore() 1485 self._backend = DummyRendezvousBackend() 1486 self._min_nodes = 3 1487 self._max_nodes = 6 1488 self._timeout: Optional[RendezvousTimeout] = RendezvousTimeout() 1489 1490 def _create_handler(self) -> DynamicRendezvousHandler: 1491 return DynamicRendezvousHandler.from_backend( 1492 run_id=self._run_id, 1493 store=self._store, 1494 backend=self._backend, 1495 min_nodes=self._min_nodes, 1496 max_nodes=self._max_nodes, 1497 timeout=self._timeout, 1498 ) 1499 1500 def test_init_initializes_handler(self) -> None: 1501 handler = self._create_handler() 1502 1503 self.assertEqual(handler.get_backend(), self._backend.name) 1504 self.assertEqual(handler.get_run_id(), self._run_id) 1505 self.assertEqual(handler.settings.run_id, self._run_id) 1506 self.assertEqual(handler.settings.min_nodes, self._min_nodes) 1507 self.assertEqual(handler.settings.max_nodes, self._max_nodes) 1508 1509 if self._timeout is None: 1510 self.assertIsNotNone(handler.settings.timeout) 1511 else: 1512 self.assertIs(handler.settings.timeout, self._timeout) 1513 1514 def test_init_initializes_handler_if_timeout_is_not_specified(self) -> None: 1515 self._timeout = None 1516 1517 self.test_init_initializes_handler() 1518 1519 def test_init_initializes_handler_if_min_and_max_nodes_are_equal(self) -> None: 1520 self._min_nodes = 3 1521 self._max_nodes = 3 1522 1523 self.test_init_initializes_handler() 1524 1525 def test_init_raises_error_if_min_nodes_is_not_positive(self) -> None: 1526 for num in [0, -10]: 1527 with self.subTest(min_nodes=num): 1528 self._min_nodes = num 1529 1530 with self.assertRaisesRegex( 1531 ValueError, 1532 rf"^The minimum number of nodes \({num}\) must be greater than zero.$", 1533 ): 1534 self._create_handler() 1535 1536 def 
test_init_raises_error_if_max_nodes_is_less_than_min(self) -> None: 1537 self._min_nodes = 3 1538 self._max_nodes = 2 1539 1540 with self.assertRaisesRegex( 1541 ValueError, 1542 rf"^The maximum number of nodes \({self._max_nodes}\) must be greater than or equal to " 1543 "the minimum number of nodes " 1544 rf"\({self._min_nodes}\).$", 1545 ): 1546 self._create_handler() 1547 1548 1549class CreateHandlerTest(TestCase): 1550 def setUp(self) -> None: 1551 self._store = DummyStore() 1552 1553 self._backend = DummyRendezvousBackend() 1554 1555 self._params = RendezvousParameters( 1556 backend=self._backend.name, 1557 endpoint="dummy_endpoint", 1558 run_id="dummy_run_id", 1559 min_nodes=3, 1560 max_nodes=6, 1561 join_timeout="50", 1562 last_call_timeout="60", 1563 close_timeout="70", 1564 ) 1565 1566 self._expected_timeout = RendezvousTimeout( 1567 timedelta(seconds=50), timedelta(seconds=60), timedelta(seconds=70) 1568 ) 1569 1570 def test_create_handler_returns_handler(self) -> None: 1571 handler = create_handler(self._store, self._backend, self._params) 1572 1573 self.assertEqual(handler.get_backend(), self._backend.name) 1574 self.assertEqual(handler.get_run_id(), self._params.run_id) 1575 self.assertEqual(handler.settings.min_nodes, self._params.min_nodes) 1576 self.assertEqual(handler.settings.max_nodes, self._params.max_nodes) 1577 self.assertEqual(handler.settings.timeout.join, self._expected_timeout.join) 1578 self.assertEqual( 1579 handler.settings.timeout.last_call, self._expected_timeout.last_call 1580 ) 1581 self.assertEqual(handler.settings.timeout.close, self._expected_timeout.close) 1582 1583 def test_create_handler_returns_handler_if_timeout_is_not_specified(self) -> None: 1584 del self._params.config["join_timeout"] 1585 del self._params.config["last_call_timeout"] 1586 del self._params.config["close_timeout"] 1587 1588 self._expected_timeout = RendezvousTimeout() 1589 1590 self.test_create_handler_returns_handler() 1591 1592 
@patch("torch.distributed.elastic.events.record_rdzv_event") 1593 def test_create_handler_records_and_raises_exceptions(self, record_mock) -> None: 1594 with patch.object(DynamicRendezvousHandler, "from_backend") as from_mock: 1595 from_mock.side_effect = RendezvousError("test error") 1596 with self.assertRaises(RendezvousError): 1597 create_handler(self._store, self._backend, self._params) 1598 record_mock.assert_called_once() 1599 1600 def test_create_handler_rdzv_local_addr(self) -> None: 1601 params = RendezvousParameters( 1602 backend=self._backend.name, 1603 endpoint="dummy_endpoint", 1604 run_id="dummy_run_id", 1605 min_nodes=1, 1606 max_nodes=1, 1607 join_timeout="50", 1608 last_call_timeout="60", 1609 close_timeout="70", 1610 local_addr="127.0.0.2", 1611 ) 1612 store = HashStore() 1613 handler = create_handler(store, self._backend, params) 1614 rdzv_info = handler.next_rendezvous() 1615 self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "127.0.0.2") 1616 1617 1618def _ignore_exception(exception_type: Exception, fn: Callable): 1619 try: 1620 fn() 1621 except exception_type as e: 1622 pass 1623 1624 1625def _wait_for(condition, timeout=10, interval=1, name=None): 1626 def _wait_while(): 1627 while True: 1628 if condition(): 1629 break 1630 else: 1631 time.sleep(interval) 1632 1633 wait_thread = threading.Thread(target=_wait_while, name=name) 1634 wait_thread.start() 1635 wait_thread.join(timeout=timeout) 1636 1637 1638class _CapturingThread(threading.Thread): 1639 def __init__(self, target=None, name=None, args=None, kwargs=None): 1640 if args is None: 1641 args = () 1642 if kwargs is None: 1643 kwargs = {} 1644 threading.Thread.__init__( 1645 self, target=target, args=args, kwargs=kwargs, name=name 1646 ) 1647 self._result = None 1648 1649 def run(self): 1650 if self._target is not None: 1651 self._result = self._target(*self._args, **self._kwargs) 1652 1653 def join(self, *args): 1654 threading.Thread.join(self, *args) 1655 return self._result 
class IntegrationTest(TestCase):
    """End-to-end rendezvous tests.

    Several handlers share one ``HashStore`` and one in-memory backend.
    """

    def setUp(self) -> None:
        self._store = HashStore()
        self._handlers = []
        self._backend = _InMemoryRendezvousBackend()

    def tearDown(self) -> None:
        for handler in self._handlers:
            handler._stop_heartbeats()

    def _create_handler(self, **kwargs) -> DynamicRendezvousHandler:
        """Create and track a handler; ``kwargs`` override the default parameters.

        Each handler gets a distinct local address ``127.0.0.<index>``, where
        ``index`` is the number of handlers created before it.
        """
        params = {
            "backend": self._backend.name,
            "endpoint": "dummy_endpoint",
            "run_id": "dummy_run_id",
            "min_nodes": 2,
            "max_nodes": 2,
            "join_timeout": "5",
            "local_addr": f"127.0.0.{len(self._handlers)}",
        }
        params.update(kwargs)

        rdzv_params = RendezvousParameters(**params)

        handler = create_handler(self._store, self._backend, rdzv_params)
        self._handlers.append(handler)
        return handler

    def test_all_nodes_join_rendezvous(self) -> None:
        handler1 = self._create_handler(min_nodes=2, max_nodes=2)
        handler2 = self._create_handler(min_nodes=2, max_nodes=2)

        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)

        handler1_thread.start()
        handler2_thread.start()

        rdzv_info1: RendezvousInfo = handler1_thread.join()
        rdzv_info2: RendezvousInfo = handler2_thread.join()
        self.assertEqual(rdzv_info1.store.underlying_store, self._store)
        self.assertEqual(rdzv_info2.store.underlying_store, self._store)

        self.assertNotEqual(rdzv_info1.rank, rdzv_info2.rank)

        self.assertEqual(rdzv_info1.world_size, 2)
        self.assertEqual(rdzv_info2.world_size, 2)

    def test_redundancy_list(self) -> None:
        handler1 = self._create_handler(min_nodes=2, max_nodes=2)
        handler2 = self._create_handler(min_nodes=2, max_nodes=2)
        handler3 = self._create_handler(min_nodes=2, max_nodes=2)

        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)
        handler3_thread = _CapturingThread(
            target=_ignore_exception,
            args=(RendezvousTimeoutError, lambda: handler3.next_rendezvous()),
        )

        handler1_thread.start()
        handler2_thread.start()

        # establish successful rendezvous
        handler1_thread.join()
        handler2_thread.join()

        # expect to register in redundancy list
        handler3_thread.start()

        # wait until the handler3 is registered in the redundancy list
        _wait_for(lambda: pickle.loads(self._backend.get_state()[0]).redundancy_list)

        state_and_token = self._backend.get_state()
        state = pickle.loads(state_and_token[0])
        addresses = [node.addr for node in state.redundancy_list]
        self.assertListEqual(addresses, ["127.0.0.2"])

    def test_redundancy_transition_to_wait_list_then_join_rendezvous(self) -> None:
        handler1 = self._create_handler(
            min_nodes=1,
            max_nodes=2,
        )
        handler2 = self._create_handler(
            min_nodes=1,
            max_nodes=2,
            keep_alive_interval=timedelta(seconds=1),
        )
        handler3 = self._create_handler(
            min_nodes=1,
            max_nodes=2,
        )

        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)

        handler3_thread = _CapturingThread(
            target=_ignore_exception,
            args=(RendezvousTimeoutError, lambda: handler3.next_rendezvous()),
        )

        handler1_thread.start()
        handler2_thread.start()

        # establish successful rendezvous
        handler1_thread.join()
        handler2_thread.join()

        handler3_thread.start()

        _wait_for(lambda: pickle.loads(self._backend.get_state()[0]).redundancy_list)

        # Stopping handler2's heartbeats should eventually drop it from the
        # participants and let handler3 move from redundancy to the wait list.
        handler2._stop_heartbeats()

        _wait_for(
            lambda: len(pickle.loads(self._backend.get_state()[0]).participants) == 1
        )
        _wait_for(
            lambda: len(pickle.loads(self._backend.get_state()[0]).wait_list) == 1
        )

    def test_use_agent_store_is_true_by_default(self):
        handler = self._create_handler(
            min_nodes=1,
            max_nodes=2,
        )

        self.assertTrue(handler.use_agent_store)

    @patch.dict(os.environ, {"TORCH_DISABLE_SHARE_RDZV_TCP_STORE": "1"})
    def test_use_agent_store_is_disabled(self):
        handler = self._create_handler(
            min_nodes=1,
            max_nodes=2,
        )

        self.assertFalse(handler.use_agent_store)

    @patch.object(dist, "PrefixStore")
    def test_share_tcp_store_from_backend(self, prefix_store_class_mock):
        prefix_store = Mock(spec=dist.PrefixStore)
        prefix_store_class_mock.return_value = prefix_store

        tcp_store = Mock(spec=dist.TCPStore)
        expected_addr = "expected_address"
        expected_port = 54321
        type(tcp_store).host = PropertyMock(return_value=expected_addr)
        type(tcp_store).port = PropertyMock(return_value=expected_port)
        # this will be injected
        self._store = tcp_store

        handler1 = self._create_handler(min_nodes=2, max_nodes=2)
        handler2 = self._create_handler(min_nodes=2, max_nodes=2)

        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)

        handler1_thread.start()
        handler2_thread.start()

        rdzv_info1: RendezvousInfo = handler1_thread.join()
        rdzv_info2: RendezvousInfo = handler2_thread.join()

        self.assertEqual(rdzv_info1.store, prefix_store)
        self.assertEqual(rdzv_info2.store, prefix_store)
        prefix_store_class_mock.assert_called_with(
            "torch.rendezvous.dummy_run_id.0", tcp_store
        )

        self.assertEqual(
            rdzv_info1.bootstrap_store_info, rdzv_info2.bootstrap_store_info
        )

        self.assertEqual(rdzv_info1.bootstrap_store_info.master_addr, expected_addr)
        self.assertEqual(rdzv_info1.bootstrap_store_info.master_port, expected_port)

    @patch.dict(os.environ, {"TORCH_DISABLE_SHARE_RDZV_TCP_STORE": "1"})
    @patch.object(dist, "PrefixStore")
    def test_share_tcp_store_is_disabled(self, prefix_store_class_mock):
        prefix_store = Mock()
        prefix_store_class_mock.return_value = prefix_store

        prefix_store.set.return_value = None
        prefix_store.get.return_value = b"123"
        tcp_store = Mock(spec=dist.TCPStore)
        # this will be injected
        self._store = tcp_store

        handler1 = self._create_handler(min_nodes=2, max_nodes=2)
        handler2 = self._create_handler(min_nodes=2, max_nodes=2)

        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)

        handler1_thread.start()
        handler2_thread.start()

        rdzv_info1: RendezvousInfo = handler1_thread.join()
        rdzv_info2: RendezvousInfo = handler2_thread.join()

        self.assertEqual(rdzv_info1.store, prefix_store)
        self.assertEqual(rdzv_info2.store, prefix_store)
        prefix_store_class_mock.assert_called_with(
            "torch.rendezvous.dummy_run_id.0", self._store
        )
        self.assertEqual(rdzv_info1.bootstrap_store_info.master_port, 123)
        self.assertEqual(rdzv_info2.bootstrap_store_info.master_port, 123)


class _InMemoryRendezvousBackend(RendezvousBackend):
    """Thread-safe, in-process ``RendezvousBackend`` using integer fencing tokens.

    The token starts at ``0`` on the first successful write and is incremented
    on every later successful write.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._state = None
        self._token = None

    @property
    def name(self):
        return "_in_memory_backend"

    def get_state(self):
        """Return ``(state, token)``, or ``None`` if no state was written yet."""
        with self._lock:
            if self._state is None:
                return None
            return (self._state, self._token)

    def set_state(self, state, token):
        """Write ``state`` only if ``token`` matches the current fencing token.

        Returns:
            ``(state, token, True)`` with the new token when the write succeeds.
            ``(current_state, current_token, False)`` when the token is stale.
            ``None`` when the token is stale and no state exists yet.

        Raises:
            ValueError: If ``state`` is ``None``.
        """
        if state is None:
            raise ValueError("State cannot be None.")
        with self._lock:
            # A ``None`` token only wins while no state exists yet (``_token`` is
            # then also ``None``); any other mismatch is a lost race.
            if token != self._token:
                if self._state is None:
                    return None
                return (self._state, self._token, False)

            self._state = state
            self._token = self._token + 1 if self._token is not None else 0
            return (self._state, self._token, True)