xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: r2p"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import copy
10import os
11import pickle
12import socket
13import threading
14import time
15from abc import ABC, abstractmethod
16from base64 import b64encode
17from datetime import datetime, timedelta
18from typing import Callable, cast, Optional, Tuple
19from unittest import TestCase
20from unittest.mock import call, MagicMock, Mock, patch, PropertyMock
21
22import torch.distributed as dist
23from torch.distributed import HashStore, Store
24from torch.distributed.elastic.rendezvous import (
25    RendezvousClosedError,
26    RendezvousError,
27    RendezvousInfo,
28    RendezvousParameters,
29    RendezvousStateError,
30    RendezvousStoreInfo,
31    RendezvousTimeoutError,
32)
33from torch.distributed.elastic.rendezvous.dynamic_rendezvous import (
34    _Action,
35    _BackendRendezvousStateHolder,
36    _DistributedRendezvousOpExecutor,
37    _NodeDesc,
38    _NodeDescGenerator,
39    _RendezvousCloseOp,
40    _RendezvousContext,
41    _RendezvousExitOp,
42    _RendezvousJoinOp,
43    _RendezvousKeepAliveOp,
44    _RendezvousState,
45    _RendezvousStateHolder,
46    create_handler,
47    DynamicRendezvousHandler,
48    RendezvousBackend,
49    RendezvousSettings,
50    RendezvousTimeout,
51    Token,
52)
53
54
class CustomAssertMixin:
    """Mixin adding attribute-wise comparisons of ``_RendezvousState`` objects.

    The class this is mixed into must provide ``assertDictEqual`` (for example
    by also inheriting from ``unittest.TestCase``).
    """

    # Supplied by the host class.
    assertDictEqual: Callable

    def assert_state_equal(
        self, actual: _RendezvousState, expected: _RendezvousState
    ) -> None:
        """Assert that ``actual`` and ``expected`` have identical attributes."""
        actual_fields = vars(actual)
        expected_fields = vars(expected)
        self.assertDictEqual(actual_fields, expected_fields)

    def assert_state_empty(self, actual: _RendezvousState) -> None:
        """Assert that ``actual`` has the attributes of a fresh state."""
        observed_fields = vars(actual)
        pristine_fields = vars(_RendezvousState())
        self.assertDictEqual(observed_fields, pristine_fields)
65
66
class RendezvousTimeoutTest(TestCase):
    """Tests for construction and validation of ``RendezvousTimeout``."""

    def test_init_initializes_timeout(self) -> None:
        timeout = RendezvousTimeout(
            timedelta(seconds=50),
            timedelta(seconds=60),
            timedelta(seconds=70),
            timedelta(seconds=80),
        )

        self.assertEqual(timeout.join, timedelta(seconds=50))
        self.assertEqual(timeout.last_call, timedelta(seconds=60))
        self.assertEqual(timeout.close, timedelta(seconds=70))
        self.assertEqual(timeout.heartbeat, timedelta(seconds=80))

    def test_init_initializes_timeout_if_no_timeout_is_specified(self) -> None:
        timeout = RendezvousTimeout()

        # Defaults used when no explicit timeouts are passed.
        self.assertEqual(timeout.join, timedelta(seconds=600))
        self.assertEqual(timeout.last_call, timedelta(seconds=30))
        self.assertEqual(timeout.close, timedelta(seconds=30))
        self.assertEqual(timeout.heartbeat, timedelta(seconds=5))

    def test_init_raises_error_if_timeout_is_not_positive(self) -> None:
        join_timeouts = [timedelta(seconds=0), timedelta(seconds=-1)]

        for join_timeout in join_timeouts:
            with self.subTest(join_timeout=join_timeout):
                with self.assertRaisesRegex(
                    ValueError,
                    # The trailing period is escaped so that it matches only
                    # a literal ".".
                    rf"^The join timeout \({join_timeout}\) must be positive\.$",
                ):
                    # Only the raised error matters; the constructor is
                    # expected to fail, so no result is bound.
                    RendezvousTimeout(join_timeout)
99
100
class NodeDescTest(TestCase):
    """Tests for the ``repr`` and hashing behavior of ``_NodeDesc``."""

    def test_repr(self) -> None:
        desc = _NodeDesc("dummy_fqdn", 3, 5)

        self.assertEqual(repr(desc), "dummy_fqdn_3_5")

    def test_hash(self) -> None:
        desc1 = _NodeDesc("dummy_fqdn", 2, 4)
        desc2 = _NodeDesc("dummy_fqdn", 3, 5)

        descs = {desc1, desc2}

        self.assertIn(desc1, descs)
        self.assertIn(desc2, descs)

        # Finding the very objects that were inserted would also succeed with
        # identity-based hashing. Rendezvous state dicts are routinely keyed
        # by separately constructed but equal descriptors, so equal values
        # must hash and compare equal.
        desc1_copy = _NodeDesc("dummy_fqdn", 2, 4)

        self.assertIsNot(desc1_copy, desc1)
        self.assertEqual(desc1_copy, desc1)
        self.assertEqual(hash(desc1_copy), hash(desc1))
        self.assertIn(desc1_copy, descs)

        # Distinct values must remain distinct set members.
        self.assertNotEqual(desc1, desc2)
        self.assertEqual(len(descs), 2)
115
116
class NodeDescGeneratorTest(TestCase):
    """Tests for descriptors produced by ``_NodeDescGenerator``."""

    def test_generate(self) -> None:
        generator = _NodeDescGenerator()

        # Each descriptor embeds the host FQDN, this process's PID and a
        # per-generator counter that is expected to start at zero.
        host = socket.getfqdn()
        process_id = os.getpid()

        for expected_local_id in range(4):
            with self.subTest(
                fqdn=host, pid=process_id, local_id=expected_local_id
            ):
                generated = generator.generate()

                expected_repr = f"{host}_{process_id}_{expected_local_id}"
                self.assertEqual(repr(generated), expected_repr)
130
131
class RendezvousStateTest(TestCase):
    """Tests for the serialized footprint of ``_RendezvousState``."""

    def test_encoded_size_is_within_expected_limit(self) -> None:
        state = _RendezvousState()
        state.round = 1
        state.complete = True
        state.deadline = datetime.utcnow()
        state.closed = True

        # fmt: off
        expected_max_sizes = (
            (   5,    2 * (2 ** 10),),  #    10 machines <=   2KB  # noqa: E201, E241, E262
            (  50,   16 * (2 ** 10),),  #   100 machines <=  16KB  # noqa: E201, E241, E262
            ( 500,  160 * (2 ** 10),),  #  1000 machines <= 160KB  # noqa: E201, E241, E262
            (5000, 1600 * (2 ** 10),),  # 10000 machines <= 1.6MB  # noqa: E201, E241, E262
        )
        # fmt: on

        def populate(num_nodes: int) -> None:
            # Each index contributes one participant and one waiting node,
            # both with a heartbeat. Keys depend only on the index, so
            # re-populating with a larger count extends the previous state
            # instead of duplicating it.
            for idx in range(num_nodes):
                running = _NodeDesc(
                    f"dummy{idx}.dummy1-dummy1-dummy1-dummy1.com", 12345, idx
                )
                waiting = _NodeDesc(
                    f"dummy{idx}.dummy2-dummy2-dummy2-dummy2.com", 67890, idx
                )

                state.participants[running] = idx
                state.wait_list.add(waiting)

                state.last_heartbeats[running] = datetime.utcnow()
                state.last_heartbeats[waiting] = datetime.utcnow()

        for num_nodes, max_byte_size in expected_max_sizes:
            with self.subTest(num_nodes=num_nodes, max_byte_size=max_byte_size):
                populate(num_nodes)

                # The state is pickled and then base64-encoded; the size limit
                # applies to the encoded form.
                encoded = b64encode(pickle.dumps(state))

                self.assertLessEqual(len(encoded), max_byte_size)
171
172
class FakeRendezvousBackend(RendezvousBackend):
    """In-memory ``RendezvousBackend`` that uses an integer fencing token.

    A token of 0 means nothing has been stored yet. Every successful write
    increments the token.
    """

    _state: Optional[bytes]
    _token: int

    def __init__(self) -> None:
        self._state, self._token = None, 0

    @property
    def name(self) -> str:
        return "fake_backend"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """Return ``(state, token)``, or ``None`` if nothing was ever stored."""
        if self._token != 0:
            return self._state, self._token  # type: ignore[return-value]
        return None

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """Store ``state`` only if ``token`` matches the current token.

        A ``None`` token is treated as 0, so the write succeeds only if the
        backend has never been written to. Returns the current state, the
        current token, and whether this write was applied.
        """
        expected_token = 0 if token is None else token

        accepted = bool(expected_token == self._token)
        if accepted:
            self._state = state
            self._token += 1

        return self._state, self._token, accepted  # type: ignore[return-value]

    def get_state_internal(self) -> _RendezvousState:
        """Unpickle the stored bytes directly, bypassing the token protocol."""
        return pickle.loads(cast(bytes, self._state))

    def set_state_internal(self, state: _RendezvousState) -> None:
        """Store ``state`` unconditionally and bump the token."""
        self._state = pickle.dumps(state)
        self._token += 1

    def corrupt_state(self) -> None:
        """Store bytes that are not a valid pickle and bump the token."""
        self._state = b"corrupt_state"
        self._token += 1
217
218
class BackendRendezvousStateHolderTest(TestCase, CustomAssertMixin):
    """Tests for ``_BackendRendezvousStateHolder`` backed by a fake backend.

    The fake backend's ``get_state``/``set_state`` are wrapped in ``MagicMock``
    objects, so they still execute. The wrappers are also attached to
    ``self._mock_backend`` so that the calls the state holder makes can be
    counted and reset together. ``datetime`` inside ``dynamic_rendezvous`` is
    patched so that ``utcnow()`` returns ``self._now``.
    """

    def setUp(self) -> None:
        self._backend = FakeRendezvousBackend()

        # Wrap the real methods so they still run while their calls are recorded.
        mock_get_state = MagicMock(wraps=self._backend.get_state)
        mock_set_state = MagicMock(wraps=self._backend.set_state)

        # Attached as children of one parent mock, so ``reset_mock()`` on the
        # parent clears both call histories.
        self._mock_backend = Mock()
        self._mock_backend.get_state = mock_get_state
        self._mock_backend.set_state = mock_set_state

        # Make the state holder's backend calls go through the wrappers.
        setattr(self._backend, "get_state", mock_get_state)  # noqa: B010
        setattr(self._backend, "set_state", mock_set_state)  # noqa: B010

        self._settings = RendezvousSettings(
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            timeout=RendezvousTimeout(),
            keep_alive_interval=timedelta(seconds=30),
            keep_alive_max_attempt=3,
        )

        # Default cache duration; the cache tests override it before creating
        # a state holder.
        self._cache_duration = 0

        self._now = datetime(2000, 1, 1, hour=0, minute=0)

        self._datetime_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
        )

        mock_datetime = self._datetime_patch.start()
        mock_datetime.utcnow.return_value = self._now

    def tearDown(self) -> None:
        self._datetime_patch.stop()

    def _create_state(self) -> _RendezvousState:
        """Build a fully populated state with heartbeats 0-90 seconds old.

        The oldest heartbeat (90 s) equals ``keep_alive_interval`` (30 s) times
        ``keep_alive_max_attempt`` (3) from ``self._settings``.
        ``test_sync_sanitizes_state`` expects every node here to survive
        sanitization.
        """
        state = _RendezvousState()
        state.round = 999
        state.complete = True
        state.deadline = self._now
        state.closed = True
        state.participants = {
            _NodeDesc("dummy1", 1, 1): 0,
            _NodeDesc("dummy2", 1, 1): 1,
            _NodeDesc("dummy3", 1, 1): 2,
        }
        state.wait_list = {
            _NodeDesc("dummy4", 1, 1),
            _NodeDesc("dummy5", 1, 1),
        }
        state.last_heartbeats = {
            _NodeDesc("dummy1", 1, 1): self._now,
            _NodeDesc("dummy2", 1, 1): self._now - timedelta(seconds=15),
            _NodeDesc("dummy3", 1, 1): self._now - timedelta(seconds=30),
            _NodeDesc("dummy4", 1, 1): self._now - timedelta(seconds=60),
            _NodeDesc("dummy5", 1, 1): self._now - timedelta(seconds=90),
        }

        return state

    def _create_state_holder(self) -> _BackendRendezvousStateHolder:
        """Build a state holder using the current ``self._cache_duration``."""
        return _BackendRendezvousStateHolder(
            self._backend, self._settings, self._cache_duration
        )

    def test_init_initializes_state_holder(self) -> None:
        state_holder = self._create_state_holder()

        self.assert_state_empty(state_holder.state)

        # ``assert_not_called()`` on the parent mock only checks calls to the
        # parent itself, never to its attached children, so it would pass
        # vacuously. Check the backend methods explicitly.
        self._mock_backend.get_state.assert_not_called()
        self._mock_backend.set_state.assert_not_called()

    def test_sync_gets_empty_state_if_backend_state_does_not_exist(self) -> None:
        state_holder = self._create_state_holder()

        has_set = state_holder.sync()

        self.assertIsNone(has_set)

        self.assert_state_empty(state_holder.state)

        self.assertEqual(self._mock_backend.get_state.call_count, 1)
        self.assertEqual(self._mock_backend.set_state.call_count, 0)

    def test_sync_gets_backend_state_if_local_state_is_clean(self) -> None:
        state_holder = self._create_state_holder()

        expected_state = self._create_state()

        for attempt in range(1, 4):
            with self.subTest(attempt=attempt):
                # Reset first so that a failed attempt cannot skew the call
                # counts checked by the following attempts.
                self._mock_backend.reset_mock()

                expected_state.round = attempt

                self._backend.set_state_internal(expected_state)

                has_set = state_holder.sync()

                self.assertIsNone(has_set)

                self.assert_state_equal(state_holder.state, expected_state)

                self.assertEqual(self._mock_backend.get_state.call_count, 1)
                self.assertEqual(self._mock_backend.set_state.call_count, 0)

    def test_sync_gets_backend_state_if_local_state_is_old_and_dirty(self) -> None:
        state_holder = self._create_state_holder()

        expected_state = self._create_state()

        for attempt in range(1, 4):
            with self.subTest(attempt=attempt):
                # Reset first so that a failed attempt cannot skew the call
                # counts checked by the following attempts.
                self._mock_backend.reset_mock()

                self._backend.set_state_internal(expected_state)  # Increment token.

                state_holder.state.round = attempt
                state_holder.mark_dirty()

                has_set = state_holder.sync()

                # The holder's token is stale, so its write is rejected and
                # the backend state wins.
                self.assertFalse(has_set)

                self.assert_state_equal(state_holder.state, expected_state)

                self.assertEqual(self._mock_backend.get_state.call_count, 0)
                self.assertEqual(self._mock_backend.set_state.call_count, 1)

    def test_sync_sets_backend_state_if_local_state_is_new_and_dirty(self) -> None:
        state_holder = self._create_state_holder()

        for attempt in range(1, 4):
            with self.subTest(attempt=attempt):
                # Reset first so that a failed attempt cannot skew the call
                # counts checked by the following attempts.
                self._mock_backend.reset_mock()

                state_holder.state.round = attempt
                state_holder.mark_dirty()

                has_set = state_holder.sync()

                self.assertTrue(has_set)

                expected_state = self._backend.get_state_internal()

                self.assert_state_equal(state_holder.state, expected_state)

                self.assertEqual(self._mock_backend.get_state.call_count, 0)
                self.assertEqual(self._mock_backend.set_state.call_count, 1)

    def test_sync_uses_cached_state_if_cache_duration_is_specified(self) -> None:
        state = self._create_state()

        self._backend.set_state_internal(state)

        with patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time"
        ) as mock_time:
            for cache_duration in [1, 5, 10]:
                with self.subTest(cache_duration=cache_duration):
                    # Reset first so that a failed subtest cannot skew the
                    # counts checked by the following ones.
                    self._mock_backend.get_state.reset_mock()

                    self._cache_duration = cache_duration

                    state_holder = self._create_state_holder()

                    mock_time.monotonic.return_value = 5

                    state_holder.sync()

                    has_set = state_holder.sync()

                    self.assertIsNone(has_set)

                    # Only the first sync reaches the backend.
                    self.assertEqual(self._mock_backend.get_state.call_count, 1)
                    self.assertEqual(self._mock_backend.set_state.call_count, 0)

                    # Exactly at the expiry boundary the cache is still used.
                    mock_time.monotonic.return_value = 5 + self._cache_duration

                    state_holder.sync()

                    has_set = state_holder.sync()

                    self.assertIsNone(has_set)

                    self.assertEqual(self._mock_backend.get_state.call_count, 1)
                    self.assertEqual(self._mock_backend.set_state.call_count, 0)

    def test_sync_gets_backend_state_if_cached_state_has_expired(self) -> None:
        state = self._create_state()

        self._backend.set_state_internal(state)

        with patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time"
        ) as mock_time:
            self._cache_duration = 1

            state_holder = self._create_state_holder()

            mock_time.monotonic.return_value = 5

            state_holder.sync()

            has_set = state_holder.sync()

            self.assertIsNone(has_set)

            self.assertEqual(self._mock_backend.get_state.call_count, 1)
            self.assertEqual(self._mock_backend.set_state.call_count, 0)

            # Just past the cache duration the backend must be queried again.
            mock_time.monotonic.return_value = 5 + self._cache_duration + 0.01

            state_holder.sync()

            has_set = state_holder.sync()

            self.assertIsNone(has_set)

            self.assertEqual(self._mock_backend.get_state.call_count, 2)
            self.assertEqual(self._mock_backend.set_state.call_count, 0)

    def test_sync_sanitizes_state(self) -> None:
        state = self._create_state()

        expected_state = copy.deepcopy(state)

        dead_node1 = _NodeDesc("dead1", 1, 1)
        dead_node2 = _NodeDesc("dead2", 1, 1)
        dead_node3 = _NodeDesc("dead3", 1, 1)
        dead_node4 = _NodeDesc("dead4", 1, 1)
        dead_node5 = _NodeDesc("dead5", 1, 1)

        # All of these heartbeats are older than 90 seconds.
        state.last_heartbeats[dead_node1] = self._now - timedelta(seconds=91)
        state.last_heartbeats[dead_node2] = self._now - timedelta(seconds=100)
        state.last_heartbeats[dead_node3] = self._now - timedelta(seconds=110)
        state.last_heartbeats[dead_node4] = self._now - timedelta(seconds=120)
        state.last_heartbeats[dead_node5] = self._now - timedelta(seconds=130)

        state.participants[dead_node1] = 0
        state.participants[dead_node2] = 0
        state.participants[dead_node3] = 0

        state.wait_list.add(dead_node4)
        state.wait_list.add(dead_node5)

        self._backend.set_state_internal(state)

        state_holder = self._create_state_holder()

        state_holder.sync()

        self.assert_state_equal(state_holder.state, expected_state)

    def test_sync_sanitizes_state_if_no_participants_is_left(self) -> None:
        state = self._create_state()

        expected_state = copy.deepcopy(state)

        for node in state.last_heartbeats:
            state.last_heartbeats[node] = self._now - timedelta(seconds=100)

        expected_state.complete = False
        expected_state.round = 1000
        expected_state.participants = {}
        expected_state.wait_list = set()
        expected_state.last_heartbeats = {}

        self._backend.set_state_internal(state)

        state_holder = self._create_state_holder()

        state_holder.sync()

        self.assert_state_equal(state_holder.state, expected_state)

    def test_sync_raises_error_if_backend_state_is_corrupt(self) -> None:
        self._backend.corrupt_state()

        state_holder = self._create_state_holder()

        with self.assertRaisesRegex(
            RendezvousStateError,
            r"^The rendezvous state is corrupt. See inner exception for details.$",
        ):
            state_holder.sync()
507
508
class FakeRendezvousStateHolder(_RendezvousStateHolder):
    """In-memory ``_RendezvousStateHolder`` that never talks to a backend.

    ``sync()`` returns ``True`` if ``mark_dirty()`` was called since the
    previous ``sync()`` and ``None`` otherwise. Either way it clears the flag.
    """

    _state: _RendezvousState
    _dirty: Optional[bool]

    def __init__(self) -> None:
        self._state = _RendezvousState()
        self._dirty = None

    @property
    def state(self) -> _RendezvousState:
        return self._state

    @state.setter
    def state(self, value: _RendezvousState) -> None:
        self._state = value

    def sync(self) -> Optional[bool]:
        previous = self._dirty
        self._dirty = None
        return previous

    def mark_dirty(self) -> None:
        self._dirty = True
532
533
class DistributedRendezvousOpExecutorTest(TestCase, CustomAssertMixin):
    """Tests for how ``_DistributedRendezvousOpExecutor`` applies ``_Action``s.

    The executor runs against a ``FakeRendezvousStateHolder`` whose ``sync``
    and ``mark_dirty`` are wrapped in ``MagicMock``s. The wrappers are attached
    to ``self._mock_state_holder``, with ``mark_dirty`` recorded under the name
    ``mark``, so the order of state-holder calls can be asserted. ``datetime``
    inside ``dynamic_rendezvous`` is patched so that ``utcnow()`` returns
    ``self._now``.
    """

    def setUp(self) -> None:
        self._node = _NodeDesc("this_node", 1, 1)

        self._state_holder = FakeRendezvousStateHolder()

        # Wrap the real methods so they still run while their calls are recorded.
        mock_sync = MagicMock(wraps=self._state_holder.sync)
        mock_mark = MagicMock(wraps=self._state_holder.mark_dirty)

        # Attached as children of one parent mock. ``mock_calls`` on the
        # parent then lists ``sync``/``mark`` calls in order, and
        # ``reset_mock()`` clears both.
        self._mock_state_holder = Mock()
        self._mock_state_holder.sync = mock_sync
        self._mock_state_holder.mark = mock_mark

        # Make the executor's calls go through the wrappers.
        setattr(self._state_holder, "sync", mock_sync)  # noqa: B010
        setattr(self._state_holder, "mark_dirty", mock_mark)  # noqa: B010

        self._state = self._state_holder.state

        self._min_nodes = 1
        self._max_nodes = 1

        self._timeout = RendezvousTimeout()

        self._now = datetime(2000, 1, 1, hour=0, minute=0)

        self._datetime_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
        )

        mock_datetime = self._datetime_patch.start()
        mock_datetime.utcnow.return_value = self._now

    def tearDown(self) -> None:
        self._datetime_patch.stop()

    def _create_settings(self) -> RendezvousSettings:
        """Build settings from the current node limits and timeout."""
        return RendezvousSettings(
            run_id="dummy_run_id",
            min_nodes=self._min_nodes,
            max_nodes=self._max_nodes,
            timeout=self._timeout,
            keep_alive_interval=timedelta(seconds=30),
            keep_alive_max_attempt=3,
        )

    def _create_op_executor(
        self, settings: Optional[RendezvousSettings] = None
    ) -> _DistributedRendezvousOpExecutor:
        """Build an executor for ``self._node``.

        Tests may replace ``self._state`` after ``setUp``. It is installed into
        the state holder here so that the executor operates on it.
        """
        self._state_holder.state = self._state

        if settings is None:
            settings = self._create_settings()

        return _DistributedRendezvousOpExecutor(
            self._node, self._state_holder, settings
        )

    def _run_action(self, action: _Action) -> None:
        """Run an op that yields ``action`` once and then ``FINISH``."""
        op_executor = self._create_op_executor()

        op = MagicMock(side_effect=[action, _Action.FINISH])

        op_executor.run(op, deadline=1)

    def _assert_action(self, action: _Action, expected_state: _RendezvousState) -> None:
        """Run ``action``, then check the resulting state and holder calls.

        A state-changing action is expected to produce exactly the calls
        ``sync``, ``mark`` (i.e. ``mark_dirty``) and ``sync``.
        """
        # Start from a clean call history so that a failure in one loop
        # iteration of a test cannot cascade into the following iterations.
        self._mock_state_holder.reset_mock()

        self._run_action(action)

        self.assert_state_equal(self._state, expected_state)

        self.assertListEqual(
            self._mock_state_holder.mock_calls, [call.sync(), call.mark(), call.sync()]
        )

    def test_run_passes_expected_context_and_deadline_to_state_handler(self) -> None:
        settings = self._create_settings()

        op_executor = self._create_op_executor(settings)

        op = MagicMock(return_value=_Action.FINISH)

        op_executor.run(op, deadline=3)

        ctx, deadline = op.call_args[0]  # args

        self.assertIs(ctx.node, self._node)
        self.assertIs(ctx.state, self._state)
        self.assertIs(ctx.settings, settings)

        self.assertEqual(deadline, 3)

    def test_run_keeps_alive(self) -> None:
        expected_state = _RendezvousState()

        expected_state.last_heartbeats[self._node] = self._now

        self._assert_action(_Action.KEEP_ALIVE, expected_state)

    def test_run_adds_to_participants(self) -> None:
        expected_state = _RendezvousState()

        expected_state.participants[self._node] = 0

        expected_state.last_heartbeats[self._node] = self._now

        self._min_nodes = 2
        self._max_nodes = 2

        self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

    def test_run_adds_to_participants_if_node_was_in_waitlist(self) -> None:
        self._state.wait_list.add(self._node)

        expected_state = _RendezvousState()

        expected_state.participants[self._node] = 0

        expected_state.last_heartbeats[self._node] = self._now

        self._min_nodes = 2
        self._max_nodes = 2

        self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

    def _add_participants(
        self, num_participants: int, state: _RendezvousState, ranked: bool = False
    ) -> None:
        """Add ``num_participants`` dummy participants with fresh heartbeats.

        If ``ranked`` is true, ``dummy{i}`` gets rank ``i``. Otherwise the
        nodes are inserted in reverse name order and all get rank 0.
        """
        for i in range(num_participants):
            if ranked:
                node = _NodeDesc(f"dummy{i}", 1, 1)
                rank = i
            else:
                node = _NodeDesc(
                    f"dummy{num_participants - i - 1}", 1, 1
                )  # Add in reverse.
                rank = 0

            state.participants[node] = rank

            state.last_heartbeats[node] = self._now

    def test_run_adds_to_participants_and_starts_last_call_if_min_nodes_is_reached(
        self,
    ) -> None:
        for num_participants in range(3):
            self._state = _RendezvousState()

            self._add_participants(num_participants, self._state)

            self._state.wait_list.add(self._node)

            expected_state = _RendezvousState()

            self._add_participants(num_participants, expected_state)

            expected_state.participants[self._node] = 0

            expected_state.last_heartbeats[self._node] = self._now

            expected_state.deadline = self._now + self._timeout.last_call

            with self.subTest(num_participants=num_participants):
                self._min_nodes = num_participants + 1
                self._max_nodes = num_participants + 2

                self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

    def test_run_adds_to_participants_and_completes_rendezvous_if_max_nodes_is_reached(
        self,
    ) -> None:
        for min_max_nodes_equal in [False, True]:
            for num_participants in range(3):
                rank = num_participants

                self._state = _RendezvousState()

                self._add_participants(num_participants, self._state)

                self._state.wait_list.add(self._node)

                self._state.deadline = self._now + self._timeout.last_call

                expected_state = _RendezvousState()

                self._add_participants(num_participants, expected_state, ranked=True)

                expected_state.participants[self._node] = rank

                expected_state.last_heartbeats[self._node] = self._now

                expected_state.complete = True
                expected_state.deadline = None

                # Both loop variables identify the case; include both so that
                # failure reports are unambiguous.
                with self.subTest(
                    min_max_nodes_equal=min_max_nodes_equal,
                    num_participants=num_participants,
                ):
                    self._min_nodes = num_participants + 1 if min_max_nodes_equal else 0
                    self._max_nodes = num_participants + 1

                    self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

    def test_run_adds_to_waitlist(self) -> None:
        expected_state = _RendezvousState()

        expected_state.wait_list.add(self._node)

        expected_state.last_heartbeats[self._node] = self._now

        self._assert_action(_Action.ADD_TO_WAIT_LIST, expected_state)

    def test_run_removes_from_participants(self) -> None:
        for complete, last_call_deadline in [(False, self._now), (True, None)]:
            self._state = _RendezvousState()

            self._add_participants(2, self._state)

            self._state.participants[self._node] = 0

            self._state.last_heartbeats[self._node] = self._now

            self._state.complete = complete
            self._state.deadline = last_call_deadline

            self._state.round = 1

            expected_state = _RendezvousState()

            self._add_participants(2, expected_state)

            expected_state.complete = complete
            expected_state.deadline = last_call_deadline

            expected_state.round = 1

            with self.subTest(complete=complete):
                self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state)

    def test_run_removes_from_participants_and_moves_to_next_round_if_node_is_last_participant(
        self,
    ) -> None:
        self._state.participants[self._node] = 0

        self._state.last_heartbeats[self._node] = self._now

        self._state.complete = True

        self._state.round = 1

        expected_state = _RendezvousState()

        expected_state.complete = False

        expected_state.round = 2

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state)

    def test_run_removes_from_participants_and_clears_last_call_if_rendezvous_has_less_than_min_nodes(
        self,
    ) -> None:
        self._add_participants(2, self._state)

        self._state.participants[self._node] = 0

        self._state.last_heartbeats[self._node] = self._now

        self._state.deadline = self._now

        expected_state = _RendezvousState()

        self._add_participants(2, expected_state)

        self._min_nodes = 3
        self._max_nodes = 4

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state)

    def test_run_removes_from_waitlist(self) -> None:
        self._state.wait_list.add(self._node)

        self._state.last_heartbeats[self._node] = self._now

        expected_state = _RendezvousState()

        self._assert_action(_Action.REMOVE_FROM_WAIT_LIST, expected_state)

    def test_run_marks_rendezvous_closed(self) -> None:
        expected_state = _RendezvousState()

        expected_state.closed = True

        self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED, expected_state)

    def test_run_raises_error_if_rendezvous_is_closed(self) -> None:
        with self.assertRaises(RendezvousClosedError):
            self._run_action(_Action.ERROR_CLOSED)

        self.assertListEqual(self._mock_state_holder.mock_calls, [call.sync()])

    def test_run_raises_error_if_operation_timed_out(self) -> None:
        with self.assertRaises(RendezvousTimeoutError):
            self._run_action(_Action.ERROR_TIMEOUT)

        self.assertListEqual(self._mock_state_holder.mock_calls, [call.sync()])

    def test_run_delays_execution_if_sync_requested(self) -> None:
        with patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous._delay"
        ) as mock_delay:
            self._run_action(_Action.SYNC)

            mock_delay.assert_called_once_with(seconds=1)

        self.assertListEqual(
            self._mock_state_holder.mock_calls, [call.sync(), call.sync()]
        )
852
853
class AbstractTestRendezvousOp(ABC):
    """Shared fixture for tests of individual rendezvous ops.

    Subclasses implement ``_create_op`` and mix in a class that provides
    ``assertEqual`` (such as ``unittest.TestCase``). ``datetime`` and ``time``
    inside ``dynamic_rendezvous`` are patched so that ``utcnow()`` and
    ``monotonic()`` return fixed values.
    """

    # Supplied by the test-case class this is combined with.
    assertEqual: Callable

    def setUp(self) -> None:
        self._node = _NodeDesc("this_node", 1, 1)

        self._min_nodes, self._max_nodes = 1, 2

        self._keep_alive_interval = timedelta(seconds=30)

        # Start from a state that already has one other participant.
        self._state = _RendezvousState()
        self._state.participants[_NodeDesc("dummy1", 1, 1)] = 1

        self._now = datetime(2000, 1, 1, hour=0, minute=0)

        self._deadline = 10

        self._datetime_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
        )
        self._datetime_patch.start().utcnow.return_value = self._now

        self._time_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time"
        )
        # monotonic() is pinned to the initial deadline. A test that lowers
        # ``_deadline`` afterwards therefore sees a deadline in the past.
        self._time_patch.start().monotonic.return_value = self._deadline

    def tearDown(self) -> None:
        # Undo the patches in the reverse of the order they were started.
        for active_patch in (self._time_patch, self._datetime_patch):
            active_patch.stop()

    def _get_next_action(self) -> _Action:
        """Invoke the op under test once and return the action it chooses."""
        op = self._create_op()

        ctx = _RendezvousContext(
            self._node,
            self._state,
            RendezvousSettings(
                run_id="dummy_run_id",
                min_nodes=self._min_nodes,
                max_nodes=self._max_nodes,
                timeout=RendezvousTimeout(),
                keep_alive_interval=self._keep_alive_interval,
                keep_alive_max_attempt=3,
            ),
        )

        return op(ctx, self._deadline)

    @abstractmethod
    def _create_op(self) -> Callable:
        """Return the op under test; it is called as ``op(ctx, deadline)``."""

    def _assert_action(self, expected_action) -> None:
        """Assert that the op under test picks ``expected_action``."""
        self.assertEqual(self._get_next_action(), expected_action)
914
915
class TestRendezvousExitOp(AbstractTestRendezvousOp, TestCase):
    """Decision tests for ``_RendezvousExitOp``."""

    def _create_op(self) -> Callable:
        exit_op = _RendezvousExitOp()
        return exit_op

    def test_removes_from_participants_if_node_is_participant(self) -> None:
        self._state.participants[self._node] = 1

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS)

    def test_raises_timeout_if_deadline_exceeded(self) -> None:
        self._state.participants[self._node] = 1
        # setUp pins time.monotonic() to the initial deadline (10), so a
        # deadline of 0 has already passed.
        self._deadline = 0

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_finishes_if_node_is_not_participant(self) -> None:
        # The fixture state does not list this node as a participant.
        self._assert_action(_Action.FINISH)
934
935
class TestRendezvousJoinOp(AbstractTestRendezvousOp, TestCase):
    """Decision table for ``_RendezvousJoinOp``.

    Each test arranges ``self._state`` (and optionally the node limits or the
    deadline) and asserts the single ``_Action`` the join op selects next.
    """

    def _create_op(self) -> Callable:
        return _RendezvousJoinOp()

    def test_raises_closed_if_rendezvous_is_closed(self) -> None:
        self._state.closed = True

        self._assert_action(_Action.ERROR_CLOSED)

    def test_finishes_if_rendezvous_is_complete_and_node_is_participant(self) -> None:
        self._state.participants[self._node] = 0

        self._state.complete = True

        self._assert_action(_Action.FINISH)

    def _assert_waits_rendezvous_completion(self) -> None:
        """Assert the node heartbeats when one is due and otherwise just syncs.

        A heartbeat exactly one keep-alive interval old is due (KEEP_ALIVE); one
        second younger is not (SYNC). Each case runs in its own ``subTest`` so a
        failure reports which heartbeat age was wrong.
        """
        keep_alive_time = self._now - self._keep_alive_interval

        for delta, expected_action in [
            (timedelta(seconds=0), _Action.KEEP_ALIVE),
            (timedelta(seconds=1), _Action.SYNC),
        ]:
            with self.subTest(delta=delta):
                self._state.last_heartbeats[self._node] = keep_alive_time + delta

                self._assert_action(expected_action)

    def test_treat_as_redundancy_for_next_rendezvous_if_rendezvous_is_complete(
        self,
    ) -> None:
        # The fixture's participant already fills max_nodes=1.
        self._max_nodes = 1

        self._state.complete = True

        self._assert_action(_Action.ADD_TO_REDUNDANCY_LIST)

    def test_waits_next_round_if_rendezvous_is_complete_and_node_is_redundant(
        self,
    ) -> None:
        self._state.redundancy_list.add(self._node)

        self._max_nodes = 1

        self._state.complete = True

        self._assert_waits_rendezvous_completion()

    def test_remove_from_rednundancy_list(self) -> None:
        # With room below max_nodes, a redundant node leaves the redundancy list.
        self._state.redundancy_list.add(self._node)

        self._max_nodes = 2

        self._state.complete = True

        self._assert_action(_Action.REMOVE_FROM_REDUNDANCY_LIST)

    def test_waits_next_round_if_rendezvous_is_complete_and_node_is_in_wait_list(
        self,
    ) -> None:
        self._state.wait_list.add(self._node)

        self._state.complete = True

        self._assert_waits_rendezvous_completion()

    def test_adds_to_wait_list_if_rendezvous_is_complete_and_num_nodes_is_less_than_max_nodes(
        self,
    ) -> None:
        self._state.complete = True

        self._assert_action(_Action.ADD_TO_WAIT_LIST)

    def test_waits_rendezvous_to_complete_if_node_is_participant(self) -> None:
        self._max_nodes = 3

        self._state.participants[self._node] = 0

        # The last-call deadline is exactly "now", so it has not passed yet.
        self._state.deadline = self._now

        self._assert_waits_rendezvous_completion()

    def test_marks_rendezvous_complete_if_node_is_participant_and_last_call_deadline_exceeded(
        self,
    ) -> None:
        self._max_nodes = 3

        self._state.participants[self._node] = 0

        self._state.deadline = self._now - timedelta(seconds=1)

        self._assert_action(_Action.MARK_RENDEZVOUS_COMPLETE)

    def test_adds_to_participants(self) -> None:
        self._assert_action(_Action.ADD_TO_PARTICIPANTS)

    def test_raises_timeout_if_deadline_exceeded(self) -> None:
        self._deadline = 0

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_raises_timeout_if_rollback_deadline_exceeded_and_node_is_participant(
        self,
    ) -> None:
        self._deadline = 0

        self._state.participants[self._node] = 0

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_raises_timeout_if_rollback_deadline_exceeded_and_node_is_in_wait_list(
        self,
    ) -> None:
        self._deadline = 0

        self._state.wait_list.add(self._node)

        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_removes_from_participants_if_timed_out_but_rollback_deadline_is_not_reached(
        self,
    ) -> None:
        # Past the deadline, but close enough that the node can still roll back.
        self._deadline = 5

        self._state.participants[self._node] = 0

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS)

    def test_removes_from_wait_list_if_timed_out_but_rollback_deadline_is_not_reached(
        self,
    ) -> None:
        self._deadline = 5

        self._state.wait_list.add(self._node)

        self._assert_action(_Action.REMOVE_FROM_WAIT_LIST)

    def test_no_timeout_for_redundant_node(self) -> None:
        # Redundant nodes are not subject to the join deadline; they keep syncing.
        self._max_nodes = 1
        self._deadline = 0
        self._state.complete = True

        self._state.redundancy_list.add(self._node)

        self._assert_action(_Action.SYNC)

    def test_keep_alive_for_redundant_node(self) -> None:
        self._deadline = 0
        self._max_nodes = 1
        self._state.complete = True

        self._state.redundancy_list.add(self._node)

        keep_alive_time = self._now - self._keep_alive_interval
        self._state.last_heartbeats[self._node] = keep_alive_time
        self._assert_action(_Action.KEEP_ALIVE)
1091
1092
class TestRendezvousCloseOp(AbstractTestRendezvousOp, TestCase):
    """Decision table for ``_RendezvousCloseOp``."""

    def _create_op(self) -> Callable:
        return _RendezvousCloseOp()

    def test_finishes_if_rendezvous_is_closed(self) -> None:
        # Closing an already-closed rendezvous is a no-op.
        self._state.closed = True
        self._assert_action(_Action.FINISH)

    def test_raises_timeout_if_deadline_exceeded(self) -> None:
        self._deadline = 0
        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_marks_rendezvous_closed(self) -> None:
        self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED)
1109
1110
class TestRendezvousKeepAliveOp(AbstractTestRendezvousOp, TestCase):
    """Decision table for ``_RendezvousKeepAliveOp``."""

    def _create_op(self) -> Callable:
        return _RendezvousKeepAliveOp()

    def test_updates_keep_alive_if_needed(self) -> None:
        # Heartbeats one full interval old, or older, are due for renewal.
        due = self._now - self._keep_alive_interval

        for offset in (timedelta(seconds=0), timedelta(seconds=-1)):
            with self.subTest(delta=offset):
                self._state.last_heartbeats[self._node] = due + offset
                self._assert_action(_Action.KEEP_ALIVE)

    def test_raises_timeout_if_deadlined_exceeded(self) -> None:
        self._deadline = 0
        stale_heartbeat = self._now - self._keep_alive_interval
        self._state.last_heartbeats[self._node] = stale_heartbeat
        self._assert_action(_Action.ERROR_TIMEOUT)

    def test_finishes_if_no_keep_alive_update_is_needed(self) -> None:
        # One second younger than the interval: no heartbeat is due yet.
        fresh_heartbeat = self._now - self._keep_alive_interval + timedelta(seconds=1)
        self._state.last_heartbeats[self._node] = fresh_heartbeat
        self._assert_action(_Action.FINISH)
1139
1140
class DummyStore(Store):
    """Bare ``Store`` subclass; tests attach whatever methods they need as mocks."""
1143
1144
class DynamicRendezvousHandlerTest(TestCase):
    """Unit tests for ``DynamicRendezvousHandler`` driven by a fake state holder.

    The store is a ``DummyStore`` whose ``get``/``set`` are replaced by
    ``MagicMock``s. The state holder is a ``FakeRendezvousStateHolder`` whose
    ``sync`` is wrapped in a ``MagicMock``, so tests can count calls or inject
    side effects. ``DynamicRendezvousHandler._create_tcp_store_server`` is
    patched to return a dummy store for the duration of each test.
    """

    def setUp(self) -> None:
        self._node = _NodeDesc("this_node", 1, 1)

        self._min_nodes = 1
        self._max_nodes = 1

        # Left as None unless a test overrides them before _create_handler().
        self._join_timeout: Optional[timedelta] = None
        self._close_timeout: Optional[timedelta] = None
        self._heartbeat_timeout: Optional[timedelta] = None

        self._keep_alive_interval = timedelta(seconds=30)

        self._store = DummyStore()

        # Store reads always return b"123"; tests inspect the keys passed in.
        self._mock_store_get = MagicMock(return_value=b"123")
        self._mock_store_set = MagicMock()

        setattr(self._store, "get", self._mock_store_get)  # noqa: B010
        setattr(self._store, "set", self._mock_store_set)  # noqa: B010

        self._state_holder = FakeRendezvousStateHolder()

        # Wrap sync() so calls can be counted and side effects injected while
        # still delegating to the fake holder by default.
        self._mock_sync = MagicMock(wraps=self._state_holder.sync)

        setattr(self._state_holder, "sync", self._mock_sync)  # noqa: B010

        self._state = self._state_holder.state

        self._tcp_store_mock = DummyStore()

        # Patch out _create_tcp_store_server so handlers receive the dummy store.
        patcher = patch.object(
            DynamicRendezvousHandler,
            "_create_tcp_store_server",
            return_value=self._tcp_store_mock,
        )
        patcher.start()
        self.addCleanup(patcher.stop)

    def _create_handler(self) -> DynamicRendezvousHandler:
        """Build a handler from the current fixture settings and state."""
        settings = RendezvousSettings(
            run_id="dummy_run_id",
            min_nodes=self._min_nodes,
            max_nodes=self._max_nodes,
            timeout=RendezvousTimeout(
                join=self._join_timeout,
                close=self._close_timeout,
                heartbeat=self._heartbeat_timeout,
            ),
            keep_alive_interval=self._keep_alive_interval,
            keep_alive_max_attempt=3,
        )

        # Push the (possibly modified) fixture state into the holder.
        self._state_holder.state = self._state

        return DynamicRendezvousHandler(
            self._node, settings, "dummy_backend", self._store, self._state_holder
        )

    def test_share_store_creates_tcp_store(self):
        handler = self._create_handler()

        shared_store_info = RendezvousStoreInfo("host", 54321)
        with patch.object(RendezvousStoreInfo, "build", return_value=shared_store_info):
            rdzv_info = handler.next_rendezvous()
            self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host")
            self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321)
        self.assertEqual(handler._shared_tcp_store_server, self._tcp_store_mock)

        # A second rendezvous is expected to reuse the same shared server.
        rdzv_info = handler.next_rendezvous()
        self.assertEqual(handler._shared_tcp_store_server, self._tcp_store_mock)

    def test_share_store_when_tcp_store(self):
        handler = self._create_handler()

        # When the handler's own store is a TCPStore, it is expected to be shared
        # directly, with its host/port used as the bootstrap address.
        with patch.object(dist, "PrefixStore", new=Mock):
            handler._store = Mock(spec=dist.TCPStore)
            type(handler._store).host = PropertyMock(return_value="host")
            type(handler._store).port = PropertyMock(return_value=54321)
            rdzv_info = handler.next_rendezvous()
            self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host")
            self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321)
            self.assertEqual(handler._shared_tcp_store_server, handler._store)

            rdzv_info = handler.next_rendezvous()
            self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host")
            self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321)
            self.assertEqual(handler._shared_tcp_store_server, handler._store)

    @patch("torch.distributed.elastic.rendezvous.dynamic_rendezvous._delay")
    def test_next_rendezvous_skews_the_first_join_attempt(self, mock_delay) -> None:
        # _delay is expected to be called only when joining round 0.
        # (The bool expectations compare equal to call counts 1 and 0.)
        for round, expected_call_count in [(0, True), (1, False)]:
            with self.subTest(round=round):
                self._state.round = round

                handler = self._create_handler()

                handler.next_rendezvous()

                self.assertEqual(mock_delay.call_count, expected_call_count)

                mock_delay.reset_mock()

    def test_next_rendezvous_returns_expected_value(self) -> None:
        self._state.participants[_NodeDesc("dummy1", 1, 1)] = 0
        self._state.participants[_NodeDesc("dummy2", 1, 1)] = 0

        self._max_nodes = 3

        handler = self._create_handler()

        rdzv_info = handler.next_rendezvous()

        # Two nodes joined first, so this node is expected to be rank 2 of 3.
        self.assertEqual(rdzv_info.rank, 2)
        self.assertEqual(rdzv_info.world_size, 3)

        _ = rdzv_info.store.get("dummy_key")

        # Keys are expected to be prefixed with the run id and round number.
        self._mock_store_get.assert_called_with(
            "torch.rendezvous.dummy_run_id.0/dummy_key"
        )

    def test_next_rendezvous_respects_the_requested_timeout(self) -> None:
        # A single sync takes longer than the whole join timeout.
        self._mock_sync.side_effect = lambda: time.sleep(0.3)

        self._join_timeout = timedelta(seconds=0.2)

        handler = self._create_handler()

        with self.assertRaises(RendezvousTimeoutError):
            handler.next_rendezvous()

    def test_next_rendezvous_moves_to_next_round_if_called_repeatedly(self) -> None:
        handler = self._create_handler()

        for i in range(4):
            handler.next_rendezvous()

            self.assertEqual(self._state.round, i)

    def test_is_closed_returns_expected_value(self) -> None:
        for closed in [False, True]:
            with self.subTest(closed=closed):
                self._state.closed = closed

                handler = self._create_handler()

                self.assertEqual(handler.is_closed(), closed)

                # is_closed() is expected to sync exactly once.
                self._mock_sync.assert_called_once()

                self._mock_sync.reset_mock()

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    def test_is_closed_records_and_raises_exceptions(self, record_mock) -> None:
        self._mock_sync.side_effect = RendezvousError("test error")
        handler = self._create_handler()
        with self.assertRaises(RendezvousError):
            handler.is_closed()
            # NOTE(review): this assertion follows the raising call inside
            # assertRaises, so it never executes. Before moving it out, confirm
            # that the handler's error path actually reaches record_rdzv_event.
            record_mock.assert_called_once()

    def test_set_closed_closes_rendezvous(self) -> None:
        handler = self._create_handler()

        handler.set_closed()

        self.assertTrue(self._state.closed)

    def test_set_closed_respects_the_requested_timeout(self) -> None:
        self._mock_sync.side_effect = lambda: time.sleep(0.3)

        self._close_timeout = timedelta(seconds=0.2)

        handler = self._create_handler()

        with self.assertRaises(RendezvousTimeoutError):
            handler.set_closed()

    def test_set_closed_can_be_called_multiple_times(self) -> None:
        handler = self._create_handler()

        handler.set_closed()
        handler.set_closed()

        self.assertTrue(self._state.closed)

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    def test_set_closed_records_and_raises_exceptions(self, record_mock) -> None:
        with patch.object(DynamicRendezvousHandler, "_close") as close_mock:
            close_mock.side_effect = RendezvousError("test error")
            handler = self._create_handler()
            with self.assertRaises(RendezvousError):
                handler.set_closed()
                # NOTE(review): unreachable, see test_is_closed_records_and_raises_exceptions.
                record_mock.assert_called_once()

    def test_num_nodes_waiting_returns_expected_value(self) -> None:
        self._state.wait_list.add(_NodeDesc("dummy1", 1, 1))
        self._state.wait_list.add(_NodeDesc("dummy2", 1, 1))

        handler = self._create_handler()

        self.assertEqual(handler.num_nodes_waiting(), 2)

        self._mock_sync.assert_called_once()

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    def test_num_nodes_waiting_records_and_raises_exceptions(self, record_mock) -> None:
        self._mock_sync.side_effect = RendezvousError("test error")
        handler = self._create_handler()
        with self.assertRaises(RendezvousError):
            handler.num_nodes_waiting()
            # NOTE(review): unreachable, see test_is_closed_records_and_raises_exceptions.
            record_mock.assert_called_once()

    def test_shutdown_closes_rendezvous_and_returns_true(self) -> None:
        handler = self._create_handler()

        result = handler.shutdown()

        self.assertTrue(result)

        self.assertTrue(self._state.closed)

    def test_shutdown_returns_false_if_rendezvous_cannot_be_closed(self) -> None:
        # The first sync fails with a RendezvousError, which shutdown() is
        # expected to report as False rather than raise.
        self._mock_sync.side_effect = [RendezvousError]

        handler = self._create_handler()

        result = handler.shutdown()

        self.assertFalse(result)

    def test_shutdown_can_be_called_multiple_times(self) -> None:
        handler = self._create_handler()

        handler.shutdown()
        handler.shutdown()

        self.assertTrue(self._state.closed)

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    def test_shutdown_records_and_raises_exceptions(self, record_mock) -> None:
        # Non-rendezvous errors (here RuntimeError) are expected to propagate.
        with patch.object(DynamicRendezvousHandler, "_close") as close_mock:
            close_mock.side_effect = RuntimeError("test error")
            handler = self._create_handler()
            with self.assertRaises(RuntimeError):
                handler.shutdown()
                # NOTE(review): unreachable, see test_is_closed_records_and_raises_exceptions.
                record_mock.assert_called_once()

    @patch("torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime")
    def test_keep_alive_updates_last_heartbeat(self, mock_datetime) -> None:
        now = datetime(2000, 1, 1, hour=0, minute=0)

        mock_datetime.utcnow.return_value = now

        self._state.last_heartbeats[self._node] = now - (self._keep_alive_interval * 2)

        handler = self._create_handler()

        handler._keep_alive()

        self.assertEqual(self._state.last_heartbeats[self._node], now)

    def _assert_keep_alive_swallows_rendezvous_errors(self) -> None:
        """Assert a failing keep-alive neither raises nor updates the heartbeat."""
        last_heartbeat_time = datetime.utcnow() - (self._keep_alive_interval * 2)

        self._state.last_heartbeats[self._node] = last_heartbeat_time

        handler = self._create_handler()

        handler._keep_alive()

        self.assertEqual(self._state.last_heartbeats[self._node], last_heartbeat_time)

    def test_keep_alive_swallows_rendezvous_errors(self) -> None:
        self._mock_sync.side_effect = [RendezvousError]

        self._assert_keep_alive_swallows_rendezvous_errors()

    def test_keep_alive_respects_the_requested_timeout(self) -> None:
        # A sync slower than the heartbeat timeout must be treated like a
        # swallowed error.
        self._mock_sync.side_effect = lambda: time.sleep(0.3)

        self._heartbeat_timeout = timedelta(seconds=0.2)

        self._assert_keep_alive_swallows_rendezvous_errors()

    def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_shutdown(
        self,
    ) -> None:
        # The keep-alive thread name is expected to end with the node's local id (2).
        self._node = _NodeDesc("this_node", 1, 2)

        name = "RendezvousKeepAliveTimer_2"

        handler = self._create_handler()

        self.assertTrue(all(t.name != name for t in threading.enumerate()))

        handler.next_rendezvous()

        self.assertTrue(any(t.name == name for t in threading.enumerate()))

        handler.shutdown()

        self.assertTrue(all(t.name != name for t in threading.enumerate()))

    def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_finalizer(
        self,
    ) -> None:
        self._node = _NodeDesc("this_node", 1, 3)

        name = "RendezvousKeepAliveTimer_3"

        handler = self._create_handler()

        self.assertTrue(all(t.name != name for t in threading.enumerate()))

        handler.next_rendezvous()

        self.assertTrue(any(t.name == name for t in threading.enumerate()))

        # Dropping the last reference is expected to stop the keep-alive thread.
        del handler

        self.assertTrue(all(t.name != name for t in threading.enumerate()))
1467
1468
class DummyRendezvousBackend(RendezvousBackend):
    """No-op backend: reads report no stored state and writes are discarded."""

    @property
    def name(self):
        return "dummy_backend"

    def get_state(self):
        """Report that no state has been stored."""
        return None

    def set_state(self, state, token):
        """Discard ``state``; ``token`` is ignored."""
        del state, token  # unused
        return None
1479
1480
class DynamicRendezvousHandlerFromBackendTest(TestCase):
    """Tests for ``DynamicRendezvousHandler.from_backend`` argument handling."""

    def setUp(self) -> None:
        self._run_id = "dummy_run_id"
        self._store = DummyStore()
        self._backend = DummyRendezvousBackend()
        self._min_nodes = 3
        self._max_nodes = 6
        self._timeout: Optional[RendezvousTimeout] = RendezvousTimeout()

    def _create_handler(self) -> DynamicRendezvousHandler:
        """Build a handler from the current fixture values via ``from_backend``."""
        return DynamicRendezvousHandler.from_backend(
            run_id=self._run_id,
            store=self._store,
            backend=self._backend,
            min_nodes=self._min_nodes,
            max_nodes=self._max_nodes,
            timeout=self._timeout,
        )

    def test_init_initializes_handler(self) -> None:
        handler = self._create_handler()

        self.assertEqual(handler.get_backend(), self._backend.name)
        self.assertEqual(handler.get_run_id(), self._run_id)
        self.assertEqual(handler.settings.run_id, self._run_id)
        self.assertEqual(handler.settings.min_nodes, self._min_nodes)
        self.assertEqual(handler.settings.max_nodes, self._max_nodes)

        if self._timeout is None:
            # Without an explicit timeout, the handler must still have one.
            self.assertIsNotNone(handler.settings.timeout)
        else:
            self.assertIs(handler.settings.timeout, self._timeout)

    def test_init_initializes_handler_if_timeout_is_not_specified(self) -> None:
        self._timeout = None

        self.test_init_initializes_handler()

    def test_init_initializes_handler_if_min_and_max_nodes_are_equal(self) -> None:
        self._min_nodes = 3
        self._max_nodes = 3

        self.test_init_initializes_handler()

    def test_init_raises_error_if_min_nodes_is_not_positive(self) -> None:
        for num in [0, -10]:
            with self.subTest(min_nodes=num):
                self._min_nodes = num

                # The trailing period is escaped so it matches literally
                # instead of any character.
                with self.assertRaisesRegex(
                    ValueError,
                    rf"^The minimum number of nodes \({num}\) must be greater than zero\.$",
                ):
                    self._create_handler()

    def test_init_raises_error_if_max_nodes_is_less_than_min(self) -> None:
        self._min_nodes = 3
        self._max_nodes = 2

        with self.assertRaisesRegex(
            ValueError,
            rf"^The maximum number of nodes \({self._max_nodes}\) must be greater than or equal to "
            "the minimum number of nodes "
            rf"\({self._min_nodes}\)\.$",
        ):
            self._create_handler()
1547
1548
class CreateHandlerTest(TestCase):
    """Tests for the ``create_handler`` factory built from ``RendezvousParameters``."""

    def setUp(self) -> None:
        self._store = DummyStore()

        self._backend = DummyRendezvousBackend()

        # Timeouts are given as strings of seconds.
        self._params = RendezvousParameters(
            backend=self._backend.name,
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=3,
            max_nodes=6,
            join_timeout="50",
            last_call_timeout="60",
            close_timeout="70",
        )

        # Positional order: join, last_call, close.
        self._expected_timeout = RendezvousTimeout(
            timedelta(seconds=50), timedelta(seconds=60), timedelta(seconds=70)
        )

    def test_create_handler_returns_handler(self) -> None:
        handler = create_handler(self._store, self._backend, self._params)

        self.assertEqual(handler.get_backend(), self._backend.name)
        self.assertEqual(handler.get_run_id(), self._params.run_id)
        self.assertEqual(handler.settings.min_nodes, self._params.min_nodes)
        self.assertEqual(handler.settings.max_nodes, self._params.max_nodes)
        self.assertEqual(handler.settings.timeout.join, self._expected_timeout.join)
        self.assertEqual(
            handler.settings.timeout.last_call, self._expected_timeout.last_call
        )
        self.assertEqual(handler.settings.timeout.close, self._expected_timeout.close)

    def test_create_handler_returns_handler_if_timeout_is_not_specified(self) -> None:
        del self._params.config["join_timeout"]
        del self._params.config["last_call_timeout"]
        del self._params.config["close_timeout"]

        self._expected_timeout = RendezvousTimeout()

        self.test_create_handler_returns_handler()

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    def test_create_handler_records_and_raises_exceptions(self, record_mock) -> None:
        with patch.object(DynamicRendezvousHandler, "from_backend") as from_mock:
            from_mock.side_effect = RendezvousError("test error")
            with self.assertRaises(RendezvousError):
                create_handler(self._store, self._backend, self._params)
                # NOTE(review): this assertion follows the raising call inside
                # assertRaises, so it never executes. Before moving it out,
                # confirm that create_handler's error path actually reaches
                # record_rdzv_event.
                record_mock.assert_called_once()

    def test_create_handler_rdzv_local_addr(self) -> None:
        params = RendezvousParameters(
            backend=self._backend.name,
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            join_timeout="50",
            last_call_timeout="60",
            close_timeout="70",
            local_addr="127.0.0.2",
        )
        store = HashStore()
        handler = create_handler(store, self._backend, params)
        # next_rendezvous() starts the handler's keep-alive heartbeats; stop
        # them at test end so the heartbeat thread does not outlive the test.
        self.addCleanup(handler._stop_heartbeats)
        rdzv_info = handler.next_rendezvous()
        self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "127.0.0.2")
1616
1617
def _ignore_exception(exception_type: "type[BaseException]", fn: Callable):
    """Call ``fn()`` and swallow exceptions of ``exception_type``.

    Used as a thread target for handlers whose rendezvous is allowed to fail
    with a specific error (e.g. a timeout).

    Args:
        exception_type: Exception class to suppress. Other exceptions propagate.
        fn: Zero-argument callable to invoke.

    Returns:
        Whatever ``fn()`` returns, or ``None`` if it raised ``exception_type``.
    """
    try:
        return fn()
    except exception_type:
        return None
1623
1624
def _wait_for(condition, timeout=10, interval=1, name=None) -> bool:
    """Poll ``condition`` on a helper thread until it is truthy or ``timeout`` elapses.

    The polling thread is a daemon and is told to stop once the wait ends, so
    it never outlives the call even if ``condition`` never becomes true.

    Args:
        condition: Zero-argument callable, polled every ``interval`` seconds.
        timeout: Maximum number of seconds to wait.
        interval: Seconds to sleep between polls.
        name: Optional name for the polling thread.

    Returns:
        ``True`` if ``condition`` returned a truthy value before the wait ended,
        ``False`` otherwise (including when ``condition`` raised).
    """
    done = threading.Event()
    satisfied = threading.Event()

    def _wait_while():
        # Re-check the stop flag each iteration so the thread exits after the
        # caller gives up, instead of looping forever.
        while not done.is_set():
            if condition():
                satisfied.set()
                return
            done.wait(interval)

    wait_thread = threading.Thread(target=_wait_while, name=name, daemon=True)
    wait_thread.start()
    wait_thread.join(timeout=timeout)
    done.set()
    return satisfied.is_set()
1636
1637
class _CapturingThread(threading.Thread):
    """Thread that hands its target's outcome back to the joining thread.

    ``join`` returns the target's return value. If the target raised, the
    exception is re-raised from ``join``. This surfaces failures in worker
    threads in the test itself, instead of the exception being printed by
    the thread and ``None`` being returned.
    """

    def __init__(self, target=None, name=None, args=None, kwargs=None):
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        super().__init__(target=target, args=args, kwargs=kwargs, name=name)
        self._result = None
        self._exception: Optional[BaseException] = None

    def run(self):
        if self._target is None:
            return
        try:
            self._result = self._target(*self._args, **self._kwargs)
        except Exception as e:
            # Hand the failure to join() rather than letting the thread print it.
            self._exception = e

    def join(self, timeout=None):
        """Wait up to ``timeout`` seconds (forever if None) and return the result.

        Returns ``None`` if the thread is still running when ``timeout`` expires.

        Raises:
            Exception: whatever the target raised, if it raised.
        """
        super().join(timeout)
        if self._exception is not None:
            raise self._exception
        return self._result
1656
1657
1658class IntegrationTest(TestCase):
1659    def setUp(self) -> None:
1660        self._store = HashStore()
1661        self._handlers = []
1662        self._backend = _InMemoryRendezvousBackend()
1663
1664    def tearDown(self) -> None:
1665        for handler in self._handlers:
1666            handler._stop_heartbeats()
1667
1668    def _create_handler(self, **kwargs) -> DynamicRendezvousHandler:
1669        params = {
1670            "backend": self._backend.name,
1671            "endpoint": "dummy_endpoint",
1672            "run_id": "dummy_run_id",
1673            "min_nodes": 2,
1674            "max_nodes": 2,
1675            "join_timeout": "5",
1676            "local_addr": f"127.0.0.{len(self._handlers)}",
1677        }
1678        params.update(**kwargs)
1679
1680        rzdv_params = RendezvousParameters(**params)
1681
1682        handler = create_handler(self._store, self._backend, rzdv_params)
1683        self._handlers.append(handler)
1684        return handler
1685
1686    def test_all_nodes_join_rendezvous(self) -> None:
1687        handler1 = self._create_handler(min_nodes=2, max_nodes=2)
1688        handler2 = self._create_handler(min_nodes=2, max_nodes=2)
1689
1690        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
1691        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)
1692
1693        handler1_thread.start()
1694        handler2_thread.start()
1695
1696        rdzv_info1: RendezvousInfo = handler1_thread.join()
1697        rdzv_info2: RendezvousInfo = handler2_thread.join()
1698        self.assertEqual(rdzv_info1.store.underlying_store, self._store)
1699        self.assertEqual(rdzv_info2.store.underlying_store, self._store)
1700
1701        self.assertNotEqual(rdzv_info1.rank, rdzv_info2.rank)
1702
1703        self.assertEqual(rdzv_info1.world_size, 2)
1704        self.assertEqual(rdzv_info2.world_size, 2)
1705
    def test_redundancy_list(self) -> None:
        # Two nodes complete a min=max=2 rendezvous. A third node that arrives
        # afterwards is expected to be parked in the redundancy list.
        handler1 = self._create_handler(min_nodes=2, max_nodes=2)
        handler2 = self._create_handler(min_nodes=2, max_nodes=2)
        handler3 = self._create_handler(min_nodes=2, max_nodes=2)

        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)
        # handler3 is not expected to complete a rendezvous; a timeout is tolerated.
        # NOTE(review): this thread is never joined, presumably because a
        # redundant node may keep waiting rather than time out.
        handler3_thread = _CapturingThread(
            target=_ignore_exception,
            args=(RendezvousTimeoutError, lambda: handler3.next_rendezvous()),
        )

        handler1_thread.start()
        handler2_thread.start()

        # establish successful rendezvous
        handler1_thread.join()
        handler2_thread.join()

        # expect to register in redundancy list
        handler3_thread.start()

        # wait until the handler3 is registered in the redundancy list
        _wait_for(lambda: pickle.loads(self._backend.get_state()[0]).redundancy_list)

        state_and_token = self._backend.get_state()
        state = pickle.loads(state_and_token[0])
        addresses = [node.addr for node in state.redundancy_list]
        # handler3 was the third handler created, so _create_handler gave it 127.0.0.2.
        self.assertListEqual(addresses, ["127.0.0.2"])
1735
    def test_redundancy_transition_to_wait_list_then_join_rendezvous(self) -> None:
        # Scenario: handler1 and handler2 complete a rendezvous (max_nodes=2) and
        # handler3 becomes redundant. Stopping handler2's heartbeats is expected
        # to drop it from the participants, and to move handler3 from the
        # redundancy list to the wait list.
        handler1 = self._create_handler(
            min_nodes=1,
            max_nodes=2,
        )
        handler2 = self._create_handler(
            min_nodes=1,
            max_nodes=2,
            keep_alive_interval=timedelta(seconds=1),
        )
        handler3 = self._create_handler(
            min_nodes=1,
            max_nodes=2,
        )

        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)

        handler3_thread = _CapturingThread(
            target=_ignore_exception,
            args=(RendezvousTimeoutError, lambda: handler3.next_rendezvous()),
        )

        handler1_thread.start()
        handler2_thread.start()

        # establish successful rendezvous
        handler1_thread.join()
        handler2_thread.join()

        handler3_thread.start()

        _wait_for(lambda: pickle.loads(self._backend.get_state()[0]).redundancy_list)

        # Simulate handler2 dying by stopping its heartbeats.
        handler2._stop_heartbeats()

        # NOTE(review): the results of these waits are never asserted, so the
        # test still passes if the conditions do not become true within the
        # timeout. The wait-list state may be transient; confirm before adding
        # hard assertions.
        _wait_for(
            lambda: len(pickle.loads(self._backend.get_state()[0]).participants) == 1
        )
        _wait_for(
            lambda: len(pickle.loads(self._backend.get_state()[0]).wait_list) == 1
        )
1778
1779    def test_use_agent_store_is_true_by_default(self):
1780        handler = self._create_handler(
1781            min_nodes=1,
1782            max_nodes=2,
1783        )
1784
1785        self.assertTrue(handler.use_agent_store)
1786
1787    @patch.dict(os.environ, {"TORCH_DISABLE_SHARE_RDZV_TCP_STORE": "1"})
1788    def test_use_agent_store_is_disabled(self):
1789        handler = self._create_handler(
1790            min_nodes=1,
1791            max_nodes=2,
1792        )
1793
1794        self.assertFalse(handler.use_agent_store)
1795
1796    @patch.object(dist, "PrefixStore")
1797    def test_share_tcp_store_from_backend(self, prefix_store_class_mock):
1798        prefix_store = Mock(spec=dist.PrefixStore)
1799        prefix_store_class_mock.return_value = prefix_store
1800
1801        tcp_store = Mock(spec=dist.TCPStore)
1802        expected_addr = "expected_address"
1803        expected_port = 54321
1804        type(tcp_store).host = PropertyMock(return_value=expected_addr)
1805        type(tcp_store).port = PropertyMock(return_value=expected_port)
1806        # this will be injected
1807        self._store = tcp_store
1808
1809        handler1 = self._create_handler(min_nodes=2, max_nodes=2)
1810        handler2 = self._create_handler(min_nodes=2, max_nodes=2)
1811
1812        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
1813        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)
1814
1815        handler1_thread.start()
1816        handler2_thread.start()
1817
1818        rdzv_info1: RendezvousInfo = handler1_thread.join()
1819        rdzv_info2: RendezvousInfo = handler2_thread.join()
1820
1821        self.assertEqual(rdzv_info1.store, prefix_store)
1822        self.assertEqual(rdzv_info2.store, prefix_store)
1823        prefix_store_class_mock.assert_called_with(
1824            "torch.rendezvous.dummy_run_id.0", tcp_store
1825        )
1826
1827        self.assertEqual(
1828            rdzv_info1.bootstrap_store_info, rdzv_info2.bootstrap_store_info
1829        )
1830
1831        self.assertEqual(rdzv_info1.bootstrap_store_info.master_addr, expected_addr)
1832        self.assertEqual(rdzv_info1.bootstrap_store_info.master_port, expected_port)
1833
1834    @patch.dict(os.environ, {"TORCH_DISABLE_SHARE_RDZV_TCP_STORE": "1"})
1835    @patch.object(dist, "PrefixStore")
1836    def test_share_tcp_store_is_disabled(self, prefix_store_class_mock):
1837        prefix_store = Mock()
1838        prefix_store_class_mock.return_value = prefix_store
1839
1840        prefix_store.set.return_value = None
1841        prefix_store.get.return_value = b"123"
1842        tcp_store = Mock(spec=dist.TCPStore)
1843        # this will be injected
1844        self._store = tcp_store
1845
1846        handler1 = self._create_handler(min_nodes=2, max_nodes=2)
1847        handler2 = self._create_handler(min_nodes=2, max_nodes=2)
1848
1849        handler1_thread = _CapturingThread(target=handler1.next_rendezvous)
1850        handler2_thread = _CapturingThread(target=handler2.next_rendezvous)
1851
1852        handler1_thread.start()
1853        handler2_thread.start()
1854
1855        rdzv_info1: RendezvousInfo = handler1_thread.join()
1856        rdzv_info2: RendezvousInfo = handler2_thread.join()
1857
1858        self.assertEqual(rdzv_info1.store, prefix_store)
1859        self.assertEqual(rdzv_info2.store, prefix_store)
1860        prefix_store_class_mock.assert_called_with(
1861            "torch.rendezvous.dummy_run_id.0", self._store
1862        )
1863        self.assertEqual(rdzv_info1.bootstrap_store_info.master_port, 123)
1864        self.assertEqual(rdzv_info2.bootstrap_store_info.master_port, 123)
1865
1866
1867class _InMemoryRendezvousBackend(RendezvousBackend):
1868    def __init__(self) -> None:
1869        self._lock = threading.Lock()
1870        self._state = None
1871        self._token = None
1872
1873    @property
1874    def name(self):
1875        return "_in_memory_backend"
1876
1877    def get_state(self):
1878        with self._lock:
1879            if self._state is None:
1880                return None
1881            return (self._state, self._token)
1882
1883        return self._state
1884
1885    def set_state(self, state, token):
1886        if state is None:
1887            raise ValueError("State cannot be None.")
1888        with self._lock:
1889            if token is None and self._token is not None:
1890                return None
1891            if self._token != token:
1892                return None
1893
1894            self._state = state
1895            self._token = self._token + 1 if self._token is not None else 0
1896