xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4 
5 namespace torch {
6 namespace distributed {
7 namespace autograd {
8 
9 using rpc::Message;
10 using rpc::MessageType;
11 
RRefBackwardReq(const rpc::RRefId & rrefId,int64_t autogradContextId,bool retainGraph)12 RRefBackwardReq::RRefBackwardReq(
13     const rpc::RRefId& rrefId,
14     int64_t autogradContextId,
15     bool retainGraph)
16     : rrefId_(rrefId),
17       autogradContextId_(autogradContextId),
18       retainGraph_(retainGraph) {}
19 
toMessageImpl()20 c10::intrusive_ptr<Message> RRefBackwardReq::toMessageImpl() && {
21   std::vector<at::IValue> ivalues;
22 
23   // Add all the fields.
24   ivalues.emplace_back(rrefId_.toIValue());
25   ivalues.emplace_back(autogradContextId_);
26   ivalues.emplace_back(retainGraph_);
27 
28   // Now pickle using JIT pickler.
29   std::vector<torch::Tensor> tensorTable;
30   std::vector<char> payload =
31       jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
32 
33   return c10::make_intrusive<Message>(
34       std::move(payload),
35       std::move(tensorTable),
36       MessageType::RREF_BACKWARD_REQ);
37 }
38 
fromMessage(const Message & message)39 std::unique_ptr<RRefBackwardReq> RRefBackwardReq::fromMessage(
40     const Message& message) {
41   // Unpickle the message and retrieve tupleElements.
42   auto payload = static_cast<const char*>(message.payload().data());
43   auto payload_size = message.payload().size();
44   IValue tuple = jit::unpickle(
45       payload,
46       payload_size,
47       *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
48       message.tensors());
49   const auto& tupleElements = std::move(*std::move(tuple).toTuple()).elements();
50 
51   // Build RRefBackwardReq.
52   TORCH_INTERNAL_ASSERT(tupleElements.size() == 3);
53 
54   // Retrieve all fields.
55   bool retainGraph = tupleElements[2].toBool();
56   int64_t autogradContextId = tupleElements[1].toInt();
57   rpc::RRefId rrefId = rpc::RRefId::fromIValue(tupleElements[0]);
58 
59   return std::make_unique<RRefBackwardReq>(
60       rrefId, autogradContextId, retainGraph);
61 }
62 
getRRefId() const63 const rpc::RRefId& RRefBackwardReq::getRRefId() const {
64   return rrefId_;
65 }
66 
getAutogradContextId() const67 int64_t RRefBackwardReq::getAutogradContextId() const {
68   return autogradContextId_;
69 }
70 
retainGraph() const71 bool RRefBackwardReq::retainGraph() const {
72   return retainGraph_;
73 }
74 
75 } // namespace autograd
76 } // namespace distributed
77 } // namespace torch
78