#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import tempfile
import unittest
from contextlib import closing

import torch.distributed.launch as launch
from torch.distributed.elastic.utils import get_socket_with_port
from torch.testing._internal.common_utils import (
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)


def path(script):
    """Return ``script`` joined onto the directory that contains this file."""
    return os.path.join(os.path.dirname(__file__), script)


class LaunchTest(unittest.TestCase):
    """Tests that drive ``torch.distributed.launch.main`` with CLI-style args."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()
        # set a sentinel env var on the parent proc
        # this should be present on the child and gets
        # asserted in ``bin/test_script.py``
        sentinel = "TEST_SENTINEL_PARENT"
        previous = os.environ.get(sentinel)
        os.environ[sentinel] = "FOOBAR"
        # Cleanups run even when tearDown() raises, so the variable never
        # leaks into tests that run later in the same process.
        self.addCleanup(self._restore_env, sentinel, previous)

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @staticmethod
    def _restore_env(name, previous):
        """Put env var ``name`` back to ``previous``; ``None`` means unset."""
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous

    def _base_launch_args(self, nnodes, nproc_per_node):
        """Build the launcher options shared by every test in this class.

        Args:
            nnodes: value passed as ``--nnodes``.
            nproc_per_node: value passed as ``--nproc-per-node``.

        Returns:
            A list of option strings, ending with ``--node-rank=0``; callers
            append any extra flags followed by the script path and its args.
        """
        sock = get_socket_with_port()
        # Only the port number is kept; the socket is closed before the
        # launcher starts (presumably so the port is free for it to bind).
        with closing(sock):
            master_port = sock.getsockname()[1]
        return [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--monitor-interval=1",
            "--start-method=spawn",
            "--master-addr=localhost",
            f"--master-port={master_port}",
            "--node-rank=0",
        ]

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_without_env(self):
        """Launch ``bin/test_script_local_rank.py`` without ``--use-env``."""
        args = self._base_launch_args(nnodes=1, nproc_per_node=4)
        args.append(path("bin/test_script_local_rank.py"))
        launch.main(args)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_with_env(self):
        """Launch ``bin/test_script.py`` with ``--use-env`` and check every rank ran."""
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = self._base_launch_args(nnodes, nproc_per_node)
        args += [
            "--use-env",
            path("bin/test_script.py"),
            f"--touch-file-dir={self.test_dir}",
        ]
        launch.main(args)
        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )