#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import os
import socket
from contextlib import closing
from typing import Optional

import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.elastic.utils.store import barrier


__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"]

logger = get_logger(__name__)

# Error messages compared (by exact equality) against ``str(e)`` of the
# ``RuntimeError`` raised during store creation / the full-rank barrier, to
# tell a port conflict and a timeout apart from other failures.
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"

# Key prefix passed to ``barrier`` when checking that all members joined.
_TCP_STORE_INIT = "_tcp_store/num_members"


def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),  # 10 min
    wait_for_workers: bool = True,
    retries=3,
    use_libuv: Optional[bool] = None,
):
    """
    Create a ``dist.TCPStore`` and, optionally, check that all members joined.

    Args:
        is_server: passed to ``TCPStore`` as ``is_master``.
        server_addr: host name of the store server.
        server_port: port of the store server. ``-1`` means "pick a free local
            port" via :func:`get_free_port`; only allowed when
            ``world_size <= 1``.
        world_size: number of members expected to join the store.
        timeout: timeout in seconds for the store and the full-rank barrier.
        wait_for_workers: forwarded to ``TCPStore``; when true, a barrier is
            also run to check that all ``world_size`` members joined.
        retries: maximum number of attempts when a dynamically picked port is
            already in use. Ignored (a single attempt is made) when
            ``server_port`` is given explicitly.
        use_libuv: deprecated and ignored; the ``USE_LIBUV`` environment
            variable (default ``"1"``) decides instead.

    Returns:
        The created ``dist.TCPStore``.

    Raises:
        ValueError: if ``server_port == -1`` while ``world_size > 1``.
        RuntimeError: if the port is still in use once all allowed attempts
            are exhausted; other ``RuntimeError``s from the store propagate.
        TimeoutError: if the full-rank barrier reports a socket timeout.
    """
    if use_libuv is not None:
        logger.warning(
            "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment "
            'variable to "0" to disable libuv, or "1" to enable it. If the env var '
            "is not set, libuv will be used by default."
        )

    # check os.environ for use_libuv
    use_libuv = os.environ.get("USE_LIBUV", "1") == "1"  # libuv is the default option

    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        logger.info("server_port: %s, specified, ignoring retries", server_port)

    # Only retry when server_port is NOT static: a freshly picked free port can
    # be taken by another process before the store binds it, so picking a new
    # one may succeed, whereas retrying the same static port is pointless.
    max_attempts = retries if server_port == -1 else 1
    attempt = 1
    while True:
        port = server_port if server_port != -1 else get_free_port()

        logger.info(
            "Creating c10d store on %s:%s\n"
            "  world_size  : %s\n"
            "  is_server   : %s\n"
            "  timeout(sec): %s\n"
            "  use_libuv   : %s\n",
            server_addr,
            port,
            world_size,
            is_server,
            timeout,
            use_libuv,
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
                wait_for_workers=wait_for_workers,
                use_libuv=use_libuv,
            )
            # skips full rank check when we don't have to wait for all workers
            if wait_for_workers:
                _check_full_rank(store, world_size, timeout=timeout)
            logger.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly pybinded
            # so we parse the error msg for now, interestingly this is how torch itself
            # detects timeouts and port conflicts in their own unittests
            # see - caffe2/torch/testing/_internal/common_utils.py
            # TODO properly map the exceptions in pybind (c10d/init.cpp)
            if str(e) != _ADDRESS_IN_USE:
                raise
            # a port conflict will only happen on the server
            if attempt >= max_attempts:
                raise RuntimeError(
                    f"on {server_addr}, port: {port} already in use"
                ) from e
            logger.warning(
                "port: %s already in use, attempt: [%s/%s]",
                port,
                attempt,
                max_attempts,
            )
            attempt += 1


def _check_full_rank(store, world_size, timeout):
    """
    Run a barrier on ``store`` that waits for ``world_size`` members to arrive.

    A ``RuntimeError`` whose message is exactly ``_SOCKET_TIMEOUT`` is
    converted to ``TimeoutError``; any other ``RuntimeError`` is re-raised.
    """
    try:
        barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
    except RuntimeError as e:
        if str(e) == _SOCKET_TIMEOUT:
            raise TimeoutError(
                f"timed out waiting for all {world_size} members to join"
            ) from e
        else:
            raise


def get_free_port():
    """
    Returns an unused port on localhost.

    This function finds an unused port on localhost by opening a socket bound
    to an OS-assigned port, reading that port, and then closing the socket.

    Returns:
        int: an unused port on localhost

    Example:
        >>> # xdoctest: +SKIP("Nondeterministic")
        >>> get_free_port()
        63976

    .. note::
        The port returned by :func:`get_free_port` is not reserved and may be
        taken by another process after this function returns.
    """
    sock = get_socket_with_port()
    with closing(sock):
        return sock.getsockname()[1]


def get_socket_with_port() -> socket.socket:
    """
    Return a listening socket bound to a free, OS-assigned port on localhost.

    The port stays "reserved" while the socket is open. Close the socket before
    passing the port to the entity that requires it. Usage example::

        sock = get_socket_with_port()
        with closing(sock):
            port = sock.getsockname()[1]
        # there is still a race-condition that some other process
        # may grab this port before func() runs
        func(port)

    Every address ``getaddrinfo("localhost")`` reports is tried in order; an
    address whose family cannot be used (socket creation or bind fails) is
    logged and skipped.

    Raises:
        RuntimeError: if no address yields a usable socket; the last
            ``OSError`` encountered is chained as the cause.
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    last_error: Optional[OSError] = None
    for family, sock_type, proto, _, _ in addrs:
        try:
            # may fail when e.g. IPv6 is unavailable; try the next address then
            s = socket.socket(family, sock_type, proto)
        except OSError as e:
            logger.warning("Socket creation attempt failed.", exc_info=e)
            last_error = e
            continue
        try:
            s.bind(("localhost", 0))  # port 0: let the OS pick a free port
            s.listen(0)
            return s
        except OSError as e:
            s.close()
            logger.warning("Socket creation attempt failed.", exc_info=e)
            last_error = e
    raise RuntimeError("Failed to create a socket") from last_error