// torch/csrc/distributed/rpc/init.cpp
#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>

#include <pybind11/chrono.h>
#include <pybind11/operators.h>

namespace torch::distributed::rpc {

namespace {

constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
  auto rpc_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
  if (!rpc_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");

  auto module = py::handle(m).cast<py::module>();

  auto rpcBackendOptions =
      shared_ptr_class_<RpcBackendOptions>(
          module,
          "RpcBackendOptions",
          R"(An abstract structure encapsulating the options passed into the RPC
            backend. An instance of this class can be passed in to
            :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
            with specific configurations, such as the RPC timeout and
            ``init_method`` to be used.)")
          .def(py::init<>())
          .def(
              py::init<float, std::string>(),
              py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
              py::arg("init_method") = kDefaultInitMethod)
          .def_readwrite(
              "rpc_timeout",
              &RpcBackendOptions::rpcTimeoutSeconds,
              R"(A float indicating the timeout to use for all
                RPCs. If an RPC does not complete in this timeframe, it will
                complete with an exception indicating that it has timed out.)")
          .def_readwrite(
              "init_method",
              &RpcBackendOptions::initMethod,
              R"(URL specifying how to initialize the process group.
                Default is ``env://``)");

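  // Illustrative usage from Python (a sketch, not part of these bindings):
  // backend options are typically built via the concrete
  // TensorPipeRpcBackendOptions subclass and handed to init_rpc, e.g.:
  //
  //   >>> import torch.distributed.rpc as rpc
  //   >>> opts = rpc.TensorPipeRpcBackendOptions(
  //   >>>     rpc_timeout=60.0, init_method="env://")
  //   >>> rpc.init_rpc("worker0", rank=0, world_size=2,
  //   >>>              rpc_backend_options=opts)
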
  // The following C++ constants need to be cast so they can be used from
  // Python.
  module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds);
  module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout);
  module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);

  auto workerInfo =
      shared_ptr_class_<WorkerInfo>(
          module,
          "WorkerInfo",
          R"(A structure that encapsulates information about a worker in the system.
            Contains the name and ID of the worker. This class is not meant to
            be constructed directly; rather, an instance can be retrieved
            through :meth:`~torch.distributed.rpc.get_worker_info` and the
            result can be passed in to functions such as
            :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`,
            :meth:`~torch.distributed.rpc.remote` to avoid copying a string on
            every invocation.)")
          .def(
              py::init<std::string, worker_id_t>(),
              py::arg("name"),
              py::arg("id"))
          .def_readonly(
              "name", &WorkerInfo::name_, R"(The name of the worker.)")
          .def_readonly(
              "id",
              &WorkerInfo::id_,
              R"(Globally unique id to identify the worker.)")
          .def("__eq__", &WorkerInfo::operator==, py::is_operator())
          // pybind11 suggests the syntax .def(hash(py::self)), with the
          // unqualified "hash" function call. However, argument-dependent
          // lookup for the function "hash" doesn't get triggered in this
          // context because it conflicts with the struct c10::hash, so we
          // need to use the qualified name py::detail::hash, which
          // unfortunately is in a detail namespace.
          .def(py::detail::hash(py::self)) // NOLINT
          .def(
              "__repr__",
              [](const WorkerInfo& workerInfo) {
                std::ostringstream os;
                os << workerInfo;
                return os.str();
              })
          .def(py::pickle(
              /* __getstate__ */
              [](const WorkerInfo& workerInfo) {
                return py::make_tuple(workerInfo.name_, workerInfo.id_);
              },
              /* __setstate__ */
              [](py::tuple t) {
                TORCH_CHECK(t.size() == 2, "Invalid WorkerInfo state.");

                WorkerInfo info(
                    t[0].cast<std::string>(), t[1].cast<worker_id_t>());
                return info;
              }));

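  // Illustrative usage from Python (a sketch; assumes RPC is initialized):
  // fetch the WorkerInfo once and reuse it, instead of re-sending the worker
  // name string on every call:
  //
  //   >>> info = rpc.get_worker_info("worker1")
  //   >>> ret = rpc.rpc_sync(info, torch.add, args=(torch.ones(2), 1))
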
  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
          .def(
              "join",
              &RpcAgent::join,
              py::call_guard<py::gil_scoped_release>(),
              py::arg("shutdown") = false,
              py::arg("timeout") = 0)
          .def(
              "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
          .def(
              "shutdown",
              &RpcAgent::shutdown,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(void) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(const std::string&) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_infos",
              &RpcAgent::getWorkerInfos,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_device_map",
              &RpcAgent::getDeviceMap,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_debug_info",
              &RpcAgent::getDebugInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_metrics",
              &RpcAgent::getMetrics,
              py::call_guard<py::gil_scoped_release>());

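  // Illustrative introspection from Python (a sketch; these are internal
  // APIs exposed through torch._C._distributed_rpc and may change):
  //
  //   >>> agent = rpc.api._get_current_rpc_agent()
  //   >>> agent.get_worker_infos()  # WorkerInfo for every known worker
  //   >>> agent.get_debug_info()    # dict of agent debug metrics
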
  auto pyRRef =
      shared_ptr_class_<PyRRef>(module, "PyRRef", R"(
          A class encapsulating a reference to a value of some type on a remote
          worker. This handle will keep the referenced remote value alive on the
          worker. A ``UserRRef`` will be deleted when 1) there are no references
          to it in either the application code or the local RRef context, or
          2) the application has called a graceful shutdown. Invoking methods on
          a deleted RRef leads to undefined behavior. The RRef implementation
          offers only best-effort error detection, and applications should not
          use ``UserRRefs`` after ``rpc.shutdown()``.

          .. warning::
              RRefs can only be serialized and deserialized by the RPC module.
              Serializing and deserializing RRefs without RPC (e.g., Python
              pickle, torch :meth:`~torch.save` / :meth:`~torch.load`,
              JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will
              lead to errors.

          Args:
              value (object): The value to be wrapped by this RRef.
              type_hint (Type, optional): Python type that should be passed to
                  ``TorchScript`` compiler as type hint for ``value``.

          Example::
              The following examples skip RPC initialization and shutdown code
              for simplicity. Refer to the RPC docs for those details.

              1. Create an RRef using rpc.remote

              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
              >>> # get a copy of value from the RRef
              >>> x = rref.to_here()

              2. Create an RRef from a local object

              >>> import torch
              >>> from torch.distributed.rpc import RRef
              >>> x = torch.zeros(2, 2)
              >>> rref = RRef(x)

              3. Share an RRef with other workers

              >>> # On both worker0 and worker1:
              >>> def f(rref):
              >>>   return rref.to_here() + 1

              >>> # On worker0:
              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> from torch.distributed.rpc import RRef
              >>> rref = RRef(torch.zeros(2, 2))
              >>> # the following RPC shares the rref with worker1, reference
              >>> # count is automatically updated.
              >>> rpc.rpc_sync("worker1", f, args=(rref,))
          )")
          .def(
              py::init<const py::object&, const py::object&>(),
              py::arg("value"),
              py::arg("type_hint") = py::none())
          .def(
              // not releasing GIL here to avoid context switch on getters
              "is_owner",
              &PyRRef::isOwner,
              R"(
                  Returns whether or not the current node is the owner of this
                  ``RRef``.
              )")
          .def(
              "confirmed_by_owner",
              &PyRRef::confirmedByOwner,
              R"(
                  Returns whether this ``RRef`` has been confirmed by the owner.
                  ``OwnerRRef`` always returns true, while ``UserRRef`` only
                  returns true when the owner knows about this ``UserRRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner",
              &PyRRef::owner,
              R"(
                  Returns worker information of the node that owns this ``RRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner_name",
              &PyRRef::ownerName,
              R"(
                  Returns worker name of the node that owns this ``RRef``.
              )")
          .def(
              "to_here",
              &PyRRef::toHere,
              py::arg("timeout") = py::cast(kUnsetRpcTimeout),
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Blocking call that copies the value of the RRef from the owner
                  to the local node and returns it. If the current node is the
                  owner, returns a reference to the local value.

                  Args:
                      timeout (float, optional): Timeout for ``to_here``. If
                          the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this
                          argument is not provided, the default RPC timeout
                          (60s) will be used.
              )")
          .def(
              "local_value",
              &PyRRef::localValue,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  If the current node is the owner, returns a reference to the
                  local value. Otherwise, throws an exception.
              )")
          .def(
              "rpc_sync",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_SYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_sync`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_sync()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_sync().size()  # returns torch.Size([2, 2])
                      >>> rref.rpc_sync().view(1, 4)  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "rpc_async",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_ASYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_async`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_async()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_async().size().wait()  # returns torch.Size([2, 2])
                      >>> rref.rpc_async().view(1, 4).wait()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "remote",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::REMOTE, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch a ``remote`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.remote().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.remote()``. If
                          the creation of this :class:`~torch.distributed.rpc.RRef`
                          is not successfully completed within the timeout, then the
                          next time there is an attempt to use the RRef
                          (such as ``to_here``), a timeout will be raised. If not
                          provided, the default RPC timeout will be used. Please see
                          ``rpc.remote()`` for specific timeout semantics for
                          :class:`~torch.distributed.rpc.RRef`.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.remote().size().to_here()  # returns torch.Size([2, 2])
                      >>> rref.remote().view(1, 4).to_here()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              py::pickle(
                  /* __getstate__ */
                  [](const PyRRef& /* unused */) {
                    TORCH_CHECK(
                        false,
                        "Cannot pickle rref with the Python pickler; rref can "
                        "only be pickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw; it's only here to satisfy pybind11's API
                    // requirement.
                    return py::make_tuple();
                  },
                  /* __setstate__ */
                  [](py::tuple /* unused */) { // NOLINT
                    TORCH_CHECK(
                        false,
                        "Cannot unpickle rref with the Python pickler; rref can "
                        "only be unpickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw; it's only here to satisfy pybind11's API
                    // requirement.
                    return PyRRef(
                        py::cast<py::none>(Py_None),
                        py::cast<py::none>(Py_None));
                  }),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_serialize",
              &PyRRef::pickle,
              py::call_guard<py::gil_scoped_release>())
          .def_static(
              "_deserialize",
              &PyRRef::unpickle,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_type",
              // Intentionally not releasing GIL, as most accesses just
              // retrieve the cached type py::object
              &PyRRef::getRRefType,
              py::arg("timeout") = kUnsetRpcTimeout,
              py::arg("blocking") = true,
              R"(
                  If ``blocking=True``, returns the type of the data object
                  referenced by this ``RRef``. On the owner, this is the same as
                  ``type(rref.local_value())``. Otherwise, returns a future to
                  this result. On a user, this will trigger an RPC to fetch the
                  ``type`` object from the owner. After this function is run
                  once, the ``type`` object is cached by the ``RRef``, and
                  subsequent invocations no longer trigger RPC. Note that this is
                  true regardless of the ``blocking`` argument of subsequent
                  calls.

                  Args:
                    rref (torch.distributed.rpc.RRef): The RRef to get the type of.
                    timeout (float, optional): Timeout, in seconds, for
                          ``_get_type``. If the call does not complete within
                          this timeframe, an exception indicating so will be
                          raised. If this argument is not provided, the default
                          RPC timeout will be used.
                    blocking (bool, optional): Whether to synchronously wait on
                          the RPC triggered by the first call and return the
                          type. If ``False``, will return a future. Default is
                          ``True``.
              )")
          .def(
              "_get_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getFuture());
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Returns the future that corresponds to the creation of this RRef
                  on the remote node. This is for internal use cases such as profiling
                  only.
              )")
          .def(
              "_get_profiling_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getProfilingFuture());
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Returns a future that completes when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "_set_profiling_future",
              [](PyRRef& self,
                 const std::shared_ptr<jit::PythonFutureWrapper>&
                     wrappedFuture) {
                self.setProfilingFuture(wrappedFuture->fut);
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Sets the future that is completed when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "backward",
              [](PyRRef& self,
                 int64_t dist_autograd_ctx_id,
                 bool retain_graph) {
                self.backward(dist_autograd_ctx_id, retain_graph);
              },
              py::arg("dist_autograd_ctx_id") = -1,
              py::arg("retain_graph") = false,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Runs the backward pass using the RRef as the root of the
                  backward pass. If ``dist_autograd_ctx_id`` is provided,
                  we perform a distributed backward pass using the provided
                  ctx_id starting from the owner of the RRef. In this case,
                  :meth:`~torch.distributed.autograd.get_gradients` should be
                  used to retrieve the gradients. If ``dist_autograd_ctx_id``
                  is ``None``, it is assumed that this is a local autograd graph
                  and we only perform a local backward pass. In the local case,
                  the node calling this API has to be the owner of the RRef.
                  The value of the RRef is expected to be a scalar Tensor.

                Args:
                    dist_autograd_ctx_id (int, optional): The distributed
                        autograd context id for which we should retrieve the
                        gradients (default: -1).
                    retain_graph (bool, optional): If ``False``, the graph used to
                        compute the grad will be freed. Note that in nearly all
                        cases setting this option to ``True`` is not needed and
                        often can be worked around in a much more efficient way.
                        Usually, you need to set this to ``True`` to run backward
                        multiple times (default: ``False``).

                Example::
                    >>> import torch.distributed.autograd as dist_autograd
                    >>> with dist_autograd.context() as context_id:
                    >>>     rref.backward(context_id)
                )")
          // not releasing GIL to avoid context switch
          .def("__repr__", &PyRRef::str);

#ifdef USE_TENSORPIPE

  // Base class: torch.distributed.rpc.RpcBackendOptions.
  py::class_<TensorPipeRpcBackendOptions>(
      module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions)
      .def(
          py::init<
              int,
              std::optional<std::vector<std::string>>,
              std::optional<std::vector<std::string>>,
              float,
              std::string,
              std::unordered_map<std::string, DeviceMap>,
              std::vector<c10::Device>>(),
          py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
          py::arg("_transports") = std::optional<std::vector<std::string>>(),
          py::arg("_channels") = std::optional<std::vector<std::string>>(),
          py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
          py::arg("init_method") = kDefaultInitMethod,
          py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(),
          py::arg("devices") = std::vector<c10::Device>())
      .def_readwrite(
          "num_worker_threads",
          &TensorPipeRpcBackendOptions::numWorkerThreads,
          R"(
              The number of threads in the thread-pool used by
              :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
              requests.
          )")
      .def_readwrite(
          "device_maps",
          &TensorPipeRpcBackendOptions::deviceMaps,
          R"(The device map locations.)")
      .def_readwrite(
          "devices",
          &TensorPipeRpcBackendOptions::devices,
          R"(All devices used by the local agent.)")
      .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap);

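  // Illustrative device-map configuration from Python (a sketch; uses the
  // public TensorPipeRpcBackendOptions subclass, whose set_device_map is
  // expected to forward to the _set_device_map binding above). Mapping local
  // cuda:0 to worker1's cuda:1 lets tensors move GPU-to-GPU rather than
  // staging through CPU:
  //
  //   >>> opts = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16)
  //   >>> opts.set_device_map("worker1", {0: 1})
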
  module.attr("_DEFAULT_NUM_WORKER_THREADS") =
      py::cast(kDefaultNumWorkerThreads);

  shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
      .def(
          py::init(
              [](const c10::intrusive_ptr<::c10d::Store>& store,
                 std::string selfName,
                 worker_id_t selfId,
                 std::optional<int> worldSize,
                 TensorPipeRpcBackendOptions opts,
                 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
                 std::vector<c10::Device> devices) {
                return std::shared_ptr<TensorPipeAgent>(
                    new TensorPipeAgent(
                        store,
                        std::move(selfName),
                        selfId,
                        worldSize,
                        std::move(opts),
                        std::move(reverseDeviceMaps),
                        std::move(devices),
                        std::make_unique<RequestCallbackImpl>()),
                    impl::destroy_without_gil<TensorPipeAgent>);
              }),
          py::arg("store"),
          py::arg("name"),
          py::arg("rank"),
          py::arg("world_size"),
          py::arg("rpc_backend_options"),
          py::arg("reverse_device_maps"),
          py::arg("devices"))
      .def(
          "join",
          &TensorPipeAgent::join,
          py::call_guard<py::gil_scoped_release>(),
          py::arg("shutdown") = false,
          py::arg("timeout") = 0)
      .def(
          "shutdown",
          &TensorPipeAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
              TensorPipeAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_device_map",
          (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) &
              TensorPipeAgent::getDeviceMap,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_backend_options",
          &TensorPipeAgent::getBackendOptions,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_update_group_membership",
          &TensorPipeAgent::updateGroupMembership,
          py::call_guard<py::gil_scoped_release>())
      .def_readonly("is_static_group", &TensorPipeAgent::isStaticGroup_)
      .def_property_readonly("store", &TensorPipeAgent::getStore);

#endif // USE_TENSORPIPE

  module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet);

  module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent);

  module.def(
      "_set_and_start_rpc_agent",
      [](const std::shared_ptr<RpcAgent>& rpcAgent) {
        RpcAgent::setCurrentRpcAgent(rpcAgent);
        // Initializing the typeResolver inside the RpcAgent constructor would
        // give RpcAgent a Python dependency. To avoid that, install it here
        // via setTypeResolver().
        std::shared_ptr<TypeResolver> typeResolver =
            std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
              auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
                  qn.qualifiedName());
              return c10::StrongTypePtr(
                  PythonRpcHandler::getInstance().jitCompilationUnit(),
                  std::move(typePtr));
            });
        rpcAgent->setTypeResolver(typeResolver);
        rpcAgent->start();
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_reset_current_rpc_agent",
      []() { RpcAgent::setCurrentRpcAgent(nullptr); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_delete_all_user_and_unforked_owner_rrefs",
      [](std::chrono::milliseconds timeoutMillis) {
        RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
      },
      py::arg("timeout") = kDeleteAllUsersTimeout,
      py::call_guard<py::gil_scoped_release>());

  module.def("_destroy_rref_context", [](bool ignoreRRefLeak) {
    // NB: do not release GIL in this function. The destroyInstance() method
    // returns a list of deleted OwnerRRefs that hold py::object instances.
    // Clearing those OwnerRRefs is likely to trigger Python derefs, which
    // require the GIL.
    RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear();
  });

  module.def("_rref_context_get_debug_info", []() {
    return RRefContext::getInstance().getDebugInfo();
  });

  module.def(
      "_cleanup_python_rpc_handler",
      []() { PythonRpcHandler::getInstance().cleanup(); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds));
      },
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
            dst,
            pickledPythonUDF,
            tensors,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const py::tuple& argsTuple,
         const py::dict& kwargsDict,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript(
            dstWorkerName,
            qualifiedNameStr,
            argsTuple,
            kwargsDict,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_builtin",
      &pyRemoteBuiltin,
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_remote_python_udf",
      &pyRemotePythonUdf,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_torchscript",
      &pyRemoteTorchscript,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "get_rpc_timeout",
      []() {
        return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() /
            kSecToMsConversion;
      },
      R"(
          Retrieve the default timeout for all RPCs that was set during RPC
          initialization. The returned value is in seconds.

          Returns:
            ``float`` indicating the RPC timeout in seconds.
      )");

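  // Illustrative usage from Python (a sketch; assumes RPC is initialized):
  //
  //   >>> rpc.get_rpc_timeout()  # default is 60 seconds unless overridden
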
  module.def(
      "enable_gil_profiling",
      [](bool flag) {
        RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag);
      },
      R"(
    Set whether GIL wait times should be recorded or not. This incurs a slight
    overhead cost, so it is disabled by default for performance reasons.

    Args:
        flag (bool): True to enable GIL profiling, False to disable.
      )");

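  // Illustrative usage from Python (a sketch; assumes RPC is initialized).
  // Once enabled, GIL wait times should appear in the agent's metrics and
  // debug info:
  //
  //   >>> rpc.enable_gil_profiling(True)
  //   >>> rpc.api._get_current_rpc_agent().get_debug_info()
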
  module.def(
      "_set_rpc_timeout",
      [](const float rpcTimeoutSeconds) {
        auto rpcTimeout = std::chrono::milliseconds(
            static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
        RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout);
      },
      R"(
          Set the default timeout for all RPCs. The input unit is expected to be
          in seconds. If an RPC is not completed within this time, an exception
          indicating it has timed out will be raised. To control timeout for
          specific RPCs, a timeout parameter can be passed into
          :meth:`~torch.distributed.rpc.rpc_sync` and
          :meth:`~torch.distributed.rpc.rpc_async`.

          Args:
            rpcTimeoutSeconds (float): Timeout value in seconds.
      )");

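  // Illustrative usage from Python (a sketch; _set_rpc_timeout is an internal
  // API and may change): raise the default timeout to 120s, or override it
  // per call via the timeout argument:
  //
  //   >>> rpc._set_rpc_timeout(120.0)
  //   >>> rpc.rpc_sync("worker1", torch.add, args=(x, 1), timeout=5.0)
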
  module.def(
      "_enable_server_process_global_profiler",
      &profiler::processglobal::enableServer);
  module.def(
      "_disable_server_process_global_profiler",
      &profiler::processglobal::disableServer);

  module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId);

  py::class_<
      RemoteProfilerManager,
      std::unique_ptr<RemoteProfilerManager, py::nodelete>>(
      module, "RemoteProfilerManager")
      .def("set_current_profiling_key", [](const std::string& key) {
        auto& inst = RemoteProfilerManager::getInstance();
        inst.setCurrentKey(key);
      });

  module.def(
      "_enable_jit_rref_pickle",
      &enableJitRRefPickle,
      R"(
        Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with
        pickled RRefs out of RPC contexts.

        .. warning::
            This is dangerous. If the module contains RRefs, the pickled
            result must be sent over RPC and get unpickled on the receiving side
            to restore the module. Otherwise, there will be RRef leaks, which
            can potentially lead to a program hang. When using this API, it is
            the application's responsibility to make sure that the above
            assumption always holds.
      )");
  module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);

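  // Illustrative usage from Python (a sketch; internal APIs, and
  // script_module_holding_rrefs is a placeholder name). Shown only to make
  // the enable/disable pairing around torch.jit.save explicit:
  //
  //   >>> rpc._enable_jit_rref_pickle()
  //   >>> torch.jit.save(script_module_holding_rrefs, "mod.pt")
  //   >>> rpc._disable_jit_rref_pickle()
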
  Py_RETURN_TRUE;
}

} // namespace

static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace torch::distributed::rpc