xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/request_callback.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/request_callback.h>
2 
3 #include <torch/csrc/distributed/autograd/context/container.h>
4 #include <torch/csrc/distributed/autograd/utils.h>
5 
6 namespace torch::distributed::rpc {
7 
8 using namespace torch::distributed::autograd;
9 
operator ()(Message & request,std::vector<c10::Stream> streams) const10 c10::intrusive_ptr<JitFuture> RequestCallback::operator()(
11     Message& request,
12     std::vector<c10::Stream> streams) const {
13   // NB: cannot clear autograd context id here because the processMessage method
14   // might pause waiting for all RRefs in the arguments to be confirmed by their
15   // owners and resume processing in a different thread. Hence, the
16   // thread_local context id needs to be set and cleared in the thread that
17   // indeed carries out the processing logic.
18   return processMessage(request, std::move(streams));
19 }
20 
21 } // namespace torch::distributed::rpc
22