# Owner(s): ["oncall: distributed"]

import os
import sys
from contextlib import closing

import torch.distributed as dist
import torch.distributed.launch as launch
from torch.distributed.elastic.utils import get_socket_with_port


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


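# Resolve a script name to a path under this test file's directory.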
def path(script):
    return os.path.join(os.path.dirname(__file__), script)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
    )
    sys.exit(0)


class TestDistributedLaunch(TestCase):
    def test_launch_user_script(self):
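        """End-to-end smoke test: launch nproc_per_node workers on a single
        node via torch.distributed.launch and run bin/test_script.py in each.
        """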
        nnodes = 1
        nproc_per_node = 4
        # world_size is not passed explicitly below; the launcher derives it
        # from --nnodes and --nproc-per-node.
        world_size = nnodes * nproc_per_node
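        # Reserve a free port for the master: bind an ephemeral socket, read
        # its port, then close it so the launcher can bind the port itself.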
        sock = get_socket_with_port()
        with closing(sock):
            master_port = sock.getsockname()[1]
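        # CLI-style flags, exactly as they would be passed to
        # `python -m torch.distributed.launch` on a command line.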
        args = [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--monitor-interval=1",
            "--start-method=spawn",
            "--master-addr=localhost",
            f"--master-port={master_port}",
            "--node-rank=0",
            "--use-env",
            path("bin/test_script.py"),
        ]
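        # Drive the launcher in-process; it raises if any worker fails.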
        launch.main(args)


if __name__ == "__main__":
    run_tests()
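
# For reference, the in-process call above is roughly equivalent to invoking
# the launcher from a shell (same flags; <port> is any free port). Note that
# torch.distributed.launch is deprecated in favor of torchrun:
#
#   python -m torch.distributed.launch --nnodes=1 --nproc-per-node=4 \
#       --monitor-interval=1 --start-method=spawn --master-addr=localhost \
#       --master-port=<port> --node-rank=0 --use-env bin/test_script.py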