#pragma once

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

// Forward declarations.
class DistAutogradContext;

// As part of our distributed autograd implementation, whenever we receive an
// RPC from a node, we add a 'RecvRpcBackward' autograd function to the
// autograd graph. This is more or less a placeholder function that is used to
// pass gradients to the remote host during the backward pass. The inputs to the
// RPC function are the inputs to this autograd function.
class TORCH_API RecvRpcBackward : public torch::autograd::Node {
 public:
  explicit RecvRpcBackward(
      const AutogradMetadata& autogradMetadata,
      std::shared_ptr<DistAutogradContext> autogradContext,
      rpc::worker_id_t fromWorkerId,
      rpc::DeviceMap deviceMap);

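  // Invoked during the distributed backward pass; propagates the incoming
  // gradients over RPC to the worker this node received the forward-pass RPC
  // from (see fromWorkerId_ below).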
  torch::autograd::variable_list apply(
      torch::autograd::variable_list&& grads) override;

 private:
  const AutogradMetadata autogradMetadata_;

  // Hold a weak reference to the autograd context to avoid circular
  // dependencies with the context (since it holds a reference to
  // RecvRpcBackward).
  std::weak_ptr<DistAutogradContext> autogradContext_;

  // The worker id from which the RPC was received. During the backward pass,
  // we need to propagate the gradients to this workerId.
  rpc::worker_id_t fromWorkerId_;

  // Device mapping for tensors sent over RPC.
  const rpc::DeviceMap deviceMap_;
};
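
// Usage sketch (illustrative only, not a verbatim copy of the implementation):
// when an RPC carrying tensors that require grad is received, a
// RecvRpcBackward node is created, attached as the grad_fn of the received
// tensors, and registered with the autograd context so the backward pass can
// reach it. The helper addRecvFunction and the exact wiring shown here are
// assumptions for illustration:
//
//   auto grad_fn = std::make_shared<RecvRpcBackward>(
//       autogradMetadata, autogradContext, fromWorkerId, deviceMap);
//   for (auto& tensor : receivedTensors) {
//     if (tensor.requires_grad()) {
//       // Make this node the gradient function of each received tensor.
//       torch::autograd::set_history(tensor, grad_fn);
//     }
//   }
//   // Record the node in the context, keyed by autograd message id, so the
//   // distributed backward pass can look it up later.
//   autogradContext->addRecvFunction(
//       grad_fn, autogradMetadata.autogradMessageId);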

} // namespace autograd
} // namespace distributed
} // namespace torch