#pragma once

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

// Forward declarations.
class DistAutogradContext;

// As part of our distributed autograd implementation, whenever we receive an
// RPC from a node, we add a 'RecvRpcBackward' autograd function to the
// autograd graph. This is more or less a placeholder function that is used to
// pass gradients to the remote host during the backward pass. The inputs to
// the RPC function are the inputs to this autograd function.
class TORCH_API RecvRpcBackward : public torch::autograd::Node {
 public:
  explicit RecvRpcBackward(
      const AutogradMetadata& autogradMetadata,
      std::shared_ptr<DistAutogradContext> autogradContext,
      rpc::worker_id_t fromWorkerId,
      rpc::DeviceMap deviceMap);

  // Invoked during the backward pass; forwards the incoming gradients to the
  // worker from which the original RPC was received.
  torch::autograd::variable_list apply(
      torch::autograd::variable_list&& grads) override;

 private:
  const AutogradMetadata autogradMetadata_;

  // Hold a weak reference to the autograd context to avoid circular
  // dependencies with the context (since it holds a reference to
  // RecvRpcBackward).
  std::weak_ptr<DistAutogradContext> autogradContext_;

  // The worker id from which the RPC was received. During the backward pass,
  // we need to propagate the gradients to this workerId.
  rpc::worker_id_t fromWorkerId_;

  // Device mapping for tensors sent over RPC.
  const rpc::DeviceMap deviceMap_;
};

} // namespace autograd
} // namespace distributed
} // namespace torch
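
// Illustrative usage sketch (not part of this header): on the receiving side,
// an RPC handler could wrap the received tensors with a RecvRpcBackward node
// so that gradients computed during the backward pass flow back to the sender.
// The names `autogradContext`, `tensors`, `autogradMetadata`, `fromWorkerId`,
// and `deviceMap` below are assumed to come from the surrounding RPC handling
// code; the actual wiring lives elsewhere in the distributed autograd
// implementation.
//
//   auto grad_fn = std::make_shared<RecvRpcBackward>(
//       autogradMetadata, autogradContext, fromWorkerId, deviceMap);
//   for (auto& tensor : tensors) {
//     if (tensor.requires_grad()) {
//       // Make grad_fn the gradient function of each received tensor so the
//       // backward pass reaches this node and propagates grads over RPC.
//       torch::autograd::set_history(tensor, grad_fn);
//     }
//   }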