# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import logging
from collections import defaultdict
from threading import Lock
from typing import List, Optional

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef

from .utils import functional_optim_map


__all__ = ["DistributedOptimizer"]

logger = logging.getLogger(__name__)


# XXX: we define a _ScriptLocalOptimizer here to explicitly
# compile the FunctionalOptimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# of a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as a common
# interface type for Optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we add TorchScript
# class reference semantics
@jit.interface
class _ScriptLocalOptimizerInterface:
    def step(self, autograd_ctx_id: int) -> None:
        pass


class _ScriptLocalOptimizer(nn.Module):
    # TorchScript does not support multithreaded concurrent compilation.
    # request_callback might invoke concurrent compilation, so we
    # serialize compilation with a lock
    compile_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        super().__init__()
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    @jit.export
    def step(self, autograd_ctx_id: int):
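        # Gather the gradients recorded in the given distributed autograd
        # context and apply a single functional optimizer step to the
        # parameters held locally by this module.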
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        # apply functional optimizer step with a list of gradients
        grads: List[Optional[Tensor]] = [
            all_local_grads[p] if p in all_local_grads else None
            for p in self._local_params
        ]

        self.optim.step(grads)


# TODO (wanchaol): remove/merge this with _ScriptLocalOptimizer once
# all optimizers in distributed.optim have been converted to functional optimizers
class _LocalOptimizer:
    # Ideally we would only need to share a lock for instances of
    # _LocalOptimizer that deal with the same parameters. We are
    # making a simplifying assumption here that if there is more
    # than one instance of _LocalOptimizer per worker, they will
    # be optimizing the same parameters (e.g. each data parallel
    # trainer will create its own instance of _LocalOptimizer but
    # they will all optimize the same parameters on each worker)
    global_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    def step(self, autograd_ctx_id):
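        """Fetch the gradients recorded in the given distributed autograd
        context, attach them to the local parameters under the shared lock,
        and run one step of the wrapped optimizer."""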
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step()


def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
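    """Construct a ``_LocalOptimizer`` over the given parameter RRefs and
    return it wrapped in an RRef."""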
    return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))


def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
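    """Run a single step of the locally held ``_LocalOptimizer`` using the
    gradients recorded in the given distributed autograd context."""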
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


# new/step functions combined with _ScriptLocalOptimizer to provide a GIL-free optimizer
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
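    """Construct a ``_ScriptLocalOptimizer``, compile it to TorchScript under
    the class-level compile lock, and return it wrapped in an RRef typed as
    ``_ScriptLocalOptimizerInterface``."""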
    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)

    with _ScriptLocalOptimizer.compile_lock:
        script_optim = jit.script(optim)
        return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)


@jit.script
def _script_local_optimizer_step(
    local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int
) -> None:
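    # Run a single step of the scripted local optimizer using the gradients
    # recorded in the given distributed autograd context.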
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


def _wait_for_all(rpc_futs):
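    """Wait on all of the given RPC futures, collecting their results; if any
    future raised, re-raise the last exception seen after all futures have
    completed."""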
    # TODO: improve error propagation
    exception = None
    results = []
    for fut in rpc_futs:
        try:
            results.append(fut.wait())
        except Exception as e:
            results.append(e)
            exception = e
    if exception is not None:
        raise exception
    return results


class DistributedOptimizer:
130    """
131    DistributedOptimizer takes remote references to parameters scattered
132    across workers and applies the given optimizer locally for each parameter.
133
134    This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
135    to retrieve the gradients for specific parameters.
136
137    Concurrent calls to
138    :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
139    either from the same or different clients, will
140    be serialized on each worker -- as each worker's optimizer can only work
141    on one set of gradients at a time. However, there is no guarantee that
142    the full forward-backward-optimizer sequence will execute for one client
143    at a time. This means that the gradients being applied may not correspond
144    to the latest forward pass executed on a given worker. Also, there is no
145    guaranteed ordering across workers.
146
147    `DistributedOptimizer` creates the local optimizer with TorchScript enabled
148    by default, so that optimizer updates are not blocked by the Python Global
149    Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
150    Model Parallel). This feature is currently enabled for most optimizers. You
151    can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
152    for your own custom optimizers.
153
154    Args:
155        optimizer_class (optim.Optimizer): the class of optimizer to
156            instantiate on each worker.
157        params_rref (list[RRef]): list of RRefs to local or remote parameters
158            to optimize.
159        args: arguments to pass to the optimizer constructor on each worker.
160        kwargs: arguments to pass to the optimizer constructor on each worker.
161
    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.autograd as dist_autograd
        >>> import torch.distributed.rpc as rpc
        >>> from torch import optim
        >>> from torch.distributed.optim import DistributedOptimizer
        >>>
        >>> with dist_autograd.context() as context_id:
        >>>   # Forward pass.
        >>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>>   loss = rref1.to_here() + rref2.to_here()
        >>>
        >>>   # Backward pass.
        >>>   dist_autograd.backward(context_id, [loss.sum()])
        >>>
        >>>   # Optimizer.
        >>>   dist_optim = DistributedOptimizer(
        >>>      optim.SGD,
        >>>      [rref1, rref2],
        >>>      lr=0.05,
        >>>   )
        >>>   dist_optim.step(context_id)

    __ https://github.com/pytorch/tutorials/pull/1465
    """

    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer")
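        # Group the parameter RRefs by the worker that owns them so that a
        # single local optimizer can be constructed per worker.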
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)

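        # Prefer the TorchScript-compatible functional optimizer when one is
        # registered for this optimizer class and TorchScript is enabled.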
        if optimizer_class in functional_optim_map and jit._state._enabled:
            optim_ctor = functional_optim_map.get(optimizer_class)
        else:
            optim_ctor = optimizer_class
        self.is_functional_optim = optim_ctor != optimizer_class

        if self.is_functional_optim:
            optimizer_new_func = _new_script_local_optimizer
        else:
            logger.warning(
                "Creating the optimizer %s without TorchScript support; "
                "this might result in slow computation time in a multithreaded environment "
                "(i.e. Distributed Model Parallel training on CPU) due to Python's "
                "Global Interpreter Lock (GIL). Please file an issue if you need this "
                "optimizer in TorchScript.",
                optimizer_class,
            )
            optimizer_new_func = _new_local_optimizer

        remote_optim_futs = []
        for worker, param_rrefs in per_worker_params_rref.items():
            remote_optim_rref_fut = rpc.rpc_async(
                worker,
                optimizer_new_func,
                args=(optim_ctor, param_rrefs) + args,
                kwargs=kwargs,
            )
            remote_optim_futs.append(remote_optim_rref_fut)

        self.remote_optimizers = _wait_for_all(remote_optim_futs)

    def step(self, context_id):
227        """
228        Performs a single optimization step.
229
230        This will call :meth:`torch.optim.Optimizer.step` on each worker
231        containing parameters to be optimized, and will block until all workers
232        return. The provided ``context_id`` will be used to retrieve the
233        corresponding :class:`~torch.distributed.autograd.context` that
234        contains the gradients that should be applied to the parameters.
235
236        Args:
237            context_id: the autograd context id for which we should run the
238                optimizer step.
239        """
        dist_autograd._is_valid_context(context_id)

        optimizer_step_func = (
            _script_local_optimizer_step
            if self.is_functional_optim
            else _local_optimizer_step
        )

        rpc_futs = []
        for optimizer in self.remote_optimizers:
            rpc_futs.append(
                rpc.rpc_async(
                    optimizer.owner(),
                    optimizer_step_func,
                    args=(optimizer, context_id),
                )
            )
        _wait_for_all(rpc_futs)