# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import logging
from collections import defaultdict
from threading import Lock
from typing import List, Optional

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef

from .utils import functional_optim_map


__all__ = ["DistributedOptimizer"]

logger = logging.getLogger(__name__)


# XXX: we define a _ScriptLocalOptimizer here to explicitly
# compile the FunctionalOptimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# in a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as a common
# interface type for Optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we add TorchScript
# class reference semantics
@jit.interface
class _ScriptLocalOptimizerInterface:
    def step(self, autograd_ctx_id: int) -> None:
        pass


class _ScriptLocalOptimizer(nn.Module):
    # TorchScript does not support multithreaded concurrent compilation.
    # request_callback might invoke concurrent compilation, so we
    # serialize compilation with a lock.
    compile_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        super().__init__()
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    @jit.export
    def step(self, autograd_ctx_id: int):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        # apply the functional optimizer step with a list of gradients
        grads: List[Optional[Tensor]] = [
            all_local_grads[p] if p in all_local_grads else None
            for p in self._local_params
        ]

        self.optim.step(grads)
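

# The contract assumed by ``_ScriptLocalOptimizer.step`` above: the functional
# optimizer's ``step`` takes a list of gradients aligned positionally with the
# parameter list given at construction, with ``None`` for parameters that
# received no gradient in this autograd context. A minimal local sketch of
# that contract, assuming ``_FunctionalSGD`` (the functional counterpart that
# ``functional_optim_map`` registers for ``optim.SGD``):
#
#   from torch.distributed.optim.functional_sgd import _FunctionalSGD
#
#   params = [torch.randn(2, requires_grad=True) for _ in range(3)]
#   functional_sgd = _FunctionalSGD(params, lr=0.05)
#   grads = [torch.ones(2), None, torch.ones(2)]  # aligned with ``params``
#   functional_sgd.step(grads)  # ``None`` entries leave their params untouched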


# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once
# we have converted all optimizers in distributed.optim to functional optimizers
class _LocalOptimizer:
    # Ideally we would only need to share a lock for instances of
    # _LocalOptimizer that deal with the same parameters. We are
    # making a simplifying assumption here that if there is more
    # than one instance of _LocalOptimizer per worker, they will
    # be optimizing the same parameters (e.g. each data parallel
    # trainer will create its own instance of _LocalOptimizer, but
    # they will all optimize the same parameters on each worker).
    global_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    def step(self, autograd_ctx_id):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step()


def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))


def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


# new/step functions combined with _ScriptLocalOptimizer to provide a GIL-free optimizer
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)

    with _ScriptLocalOptimizer.compile_lock:
        script_optim = jit.script(optim)
        return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)


@jit.script
def _script_local_optimizer_step(
    local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int
) -> None:
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)
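

# Rough sketch of how the new/step helper pairs above are meant to be driven
# over RPC (this mirrors what ``DistributedOptimizer`` below does per worker;
# ``"worker1"``, ``optim_ctor``, ``param_rrefs`` and ``context_id`` are
# assumed placeholders owned by the caller):
#
#   optim_rref = rpc.rpc_sync(
#       "worker1", _new_script_local_optimizer, args=(optim_ctor, param_rrefs)
#   )
#   rpc.rpc_sync(
#       "worker1", _script_local_optimizer_step, args=(optim_rref, context_id)
#   )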


def _wait_for_all(rpc_futs):
    # TODO: improve error propagation
    exception = None
    results = []
    for fut in rpc_futs:
        try:
            results.append(fut.wait())
        except Exception as e:
            results.append(e)
            exception = e
    if exception is not None:
        raise exception
    return results


class DistributedOptimizer:
    """
    DistributedOptimizer takes remote references to parameters scattered
    across workers and applies the given optimizer locally for each parameter.

    This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
    to retrieve the gradients for specific parameters.

    Concurrent calls to
    :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
    either from the same or different clients, will
    be serialized on each worker -- as each worker's optimizer can only work
    on one set of gradients at a time. However, there is no guarantee that
    the full forward-backward-optimizer sequence will execute for one client
    at a time. This means that the gradients being applied may not correspond
    to the latest forward pass executed on a given worker. Also, there is no
    guaranteed ordering across workers.

    `DistributedOptimizer` creates the local optimizer with TorchScript enabled
    by default, so that optimizer updates are not blocked by the Python Global
    Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
    Model Parallel). This feature is currently enabled for most optimizers. You
    can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
    for your own custom optimizers.

    Args:
        optimizer_class (optim.Optimizer): the class of optimizer to
            instantiate on each worker.
        params_rref (list[RRef]): list of RRefs to local or remote parameters
            to optimize.
        args: arguments to pass to the optimizer constructor on each worker.
        kwargs: keyword arguments to pass to the optimizer constructor on each
            worker.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.autograd as dist_autograd
        >>> import torch.distributed.rpc as rpc
        >>> from torch import optim
        >>> from torch.distributed.optim import DistributedOptimizer
        >>>
        >>> with dist_autograd.context() as context_id:
        >>>     # Forward pass.
        >>>     rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>>     rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>>     loss = rref1.to_here() + rref2.to_here()
        >>>
        >>>     # Backward pass.
        >>>     dist_autograd.backward(context_id, [loss.sum()])
        >>>
        >>>     # Optimizer.
        >>>     dist_optim = DistributedOptimizer(
        >>>         optim.SGD,
        >>>         [rref1, rref2],
        >>>         lr=0.05,
        >>>     )
        >>>     dist_optim.step(context_id)

    __ https://github.com/pytorch/tutorials/pull/1465
    """

    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer")
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)

        if optimizer_class in functional_optim_map and jit._state._enabled:
            optim_ctor = functional_optim_map.get(optimizer_class)
        else:
            optim_ctor = optimizer_class
        self.is_functional_optim = optim_ctor != optimizer_class

        if self.is_functional_optim:
            optimizer_new_func = _new_script_local_optimizer
        else:
            logger.warning(
                "Creating the optimizer %s without TorchScript support; "
                "this might result in slow computation in multithreaded environments "
                "(i.e. Distributed Model Parallel training on CPU) due to Python's "
                "Global Interpreter Lock (GIL). Please file an issue if you need this "
                "optimizer in TorchScript.",
                optimizer_class,
            )
            optimizer_new_func = _new_local_optimizer

        remote_optim_futs = []
        for worker, param_rrefs in per_worker_params_rref.items():
            remote_optim_rref_fut = rpc.rpc_async(
                worker,
                optimizer_new_func,
                args=(optim_ctor, param_rrefs) + args,
                kwargs=kwargs,
            )
            remote_optim_futs.append(remote_optim_rref_fut)

        self.remote_optimizers = _wait_for_all(remote_optim_futs)

    def step(self, context_id):
        """
        Performs a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on each worker
        containing parameters to be optimized, and will block until all workers
        return. The provided ``context_id`` will be used to retrieve the
        corresponding :class:`~torch.distributed.autograd.context` that
        contains the gradients that should be applied to the parameters.

        Args:
            context_id: the autograd context id for which we should run the
                optimizer step.
        """
        dist_autograd._is_valid_context(context_id)

        optimizer_step_func = (
            _script_local_optimizer_step
            if self.is_functional_optim
            else _local_optimizer_step
        )

        rpc_futs = []
        for optimizer in self.remote_optimizers:
            rpc_futs.append(
                rpc.rpc_async(
                    optimizer.owner(),
                    optimizer_step_func,
                    args=(optimizer, context_id),
                )
            )
        _wait_for_all(rpc_futs)
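

# Usage sketch for iterative training with ``DistributedOptimizer.step``: a
# fresh distributed autograd context is created per iteration so that
# ``step(context_id)`` applies that iteration's gradients (``dist_optim``,
# ``num_iterations`` and ``forward_pass`` are assumed placeholders defined by
# the caller):
#
#   for _ in range(num_iterations):
#       with dist_autograd.context() as context_id:
#           loss = forward_pass()
#           dist_autograd.backward(context_id, [loss])
#           dist_optim.step(context_id)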