1# Owner(s): ["oncall: distributed"] 2 3import os 4import sys 5from contextlib import closing 6 7import torch.distributed as dist 8import torch.distributed.launch as launch 9from torch.distributed.elastic.utils import get_socket_with_port 10 11 12if not dist.is_available(): 13 print("Distributed not available, skipping tests", file=sys.stderr) 14 sys.exit(0) 15 16from torch.testing._internal.common_utils import ( 17 run_tests, 18 TEST_WITH_DEV_DBG_ASAN, 19 TestCase, 20) 21 22 23def path(script): 24 return os.path.join(os.path.dirname(__file__), script) 25 26 27if TEST_WITH_DEV_DBG_ASAN: 28 print( 29 "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr 30 ) 31 sys.exit(0) 32 33 34class TestDistributedLaunch(TestCase): 35 def test_launch_user_script(self): 36 nnodes = 1 37 nproc_per_node = 4 38 world_size = nnodes * nproc_per_node 39 sock = get_socket_with_port() 40 with closing(sock): 41 master_port = sock.getsockname()[1] 42 args = [ 43 f"--nnodes={nnodes}", 44 f"--nproc-per-node={nproc_per_node}", 45 "--monitor-interval=1", 46 "--start-method=spawn", 47 "--master-addr=localhost", 48 f"--master-port={master_port}", 49 "--node-rank=0", 50 "--use-env", 51 path("bin/test_script.py"), 52 ] 53 launch.main(args) 54 55 56if __name__ == "__main__": 57 run_tests() 58