xref: /aosp_15_r20/external/pytorch/test/distributed/launcher/launch_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import os
10import shutil
11import tempfile
12import unittest
13from contextlib import closing
14
15import torch.distributed.launch as launch
16from torch.distributed.elastic.utils import get_socket_with_port
17from torch.testing._internal.common_utils import (
18    skip_but_pass_in_sandcastle_if,
19    TEST_WITH_DEV_DBG_ASAN,
20)
21
22
def path(script):
    """Join ``script`` onto the directory that contains this test file.

    Because ``os.path.join`` is used, an absolute ``script`` is returned
    unchanged.
    """
    here = os.path.dirname(__file__)
    return os.path.join(here, script)
25
26
class LaunchTest(unittest.TestCase):
    """Tests that invoke ``torch.distributed.launch.main`` on localhost.

    Each test builds a launcher command line for a single node with four
    processes and runs one of the helper scripts under ``bin/`` next to
    this file.
    """

    def setUp(self):
        """Create a scratch directory and export the parent sentinel variable."""
        self.test_dir = tempfile.mkdtemp()
        # set a sentinel env var on the parent proc
        # this should be present on the child and gets
        # asserted in ``bin/test_script.py``
        previous = os.environ.get("TEST_SENTINEL_PARENT")
        os.environ["TEST_SENTINEL_PARENT"] = "FOOBAR"
        # Put the environment back afterwards so the sentinel does not leak
        # into other tests running in the same process.
        self.addCleanup(self._restore_sentinel, previous)

    def tearDown(self):
        """Remove the scratch directory created in ``setUp``."""
        shutil.rmtree(self.test_dir)

    @staticmethod
    def _restore_sentinel(previous):
        """Reset ``TEST_SENTINEL_PARENT`` to ``previous`` (``None`` means unset)."""
        if previous is None:
            os.environ.pop("TEST_SENTINEL_PARENT", None)
        else:
            os.environ["TEST_SENTINEL_PARENT"] = previous

    def _launcher_options(self, nnodes, nproc_per_node):
        """Return the launcher options shared by every test in this class.

        Args:
            nnodes: value for ``--nnodes``.
            nproc_per_node: value for ``--nproc-per-node``.

        Returns:
            A list of option strings; callers append the script path (and
            any extra flags) after it.
        """
        # The socket is opened only to learn a port number and is closed
        # before the launcher runs.
        # NOTE(review): another process could claim the port in between.
        with closing(get_socket_with_port()) as sock:
            master_port = sock.getsockname()[1]
        return [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--monitor-interval=1",
            "--start-method=spawn",
            "--master-addr=localhost",
            f"--master-port={master_port}",
            "--node-rank=0",
        ]

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_without_env(self):
        """Run the launcher without ``--use-env`` on the local-rank script."""
        args = self._launcher_options(nnodes=1, nproc_per_node=4)
        args.append(path("bin/test_script_local_rank.py"))
        launch.main(args)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_with_env(self):
        """Run the launcher with ``--use-env`` and check every worker ran."""
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = self._launcher_options(nnodes, nproc_per_node)
        args += [
            "--use-env",
            path("bin/test_script.py"),
            f"--touch-file-dir={self.test_dir}",
        ]
        launch.main(args)
        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )
88