1 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4
5 namespace torch {
6 namespace distributed {
7 namespace autograd {
8
9 using rpc::Message;
10 using rpc::MessageType;
11
RRefBackwardReq(const rpc::RRefId & rrefId,int64_t autogradContextId,bool retainGraph)12 RRefBackwardReq::RRefBackwardReq(
13 const rpc::RRefId& rrefId,
14 int64_t autogradContextId,
15 bool retainGraph)
16 : rrefId_(rrefId),
17 autogradContextId_(autogradContextId),
18 retainGraph_(retainGraph) {}
19
toMessageImpl()20 c10::intrusive_ptr<Message> RRefBackwardReq::toMessageImpl() && {
21 std::vector<at::IValue> ivalues;
22
23 // Add all the fields.
24 ivalues.emplace_back(rrefId_.toIValue());
25 ivalues.emplace_back(autogradContextId_);
26 ivalues.emplace_back(retainGraph_);
27
28 // Now pickle using JIT pickler.
29 std::vector<torch::Tensor> tensorTable;
30 std::vector<char> payload =
31 jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
32
33 return c10::make_intrusive<Message>(
34 std::move(payload),
35 std::move(tensorTable),
36 MessageType::RREF_BACKWARD_REQ);
37 }
38
fromMessage(const Message & message)39 std::unique_ptr<RRefBackwardReq> RRefBackwardReq::fromMessage(
40 const Message& message) {
41 // Unpickle the message and retrieve tupleElements.
42 auto payload = static_cast<const char*>(message.payload().data());
43 auto payload_size = message.payload().size();
44 IValue tuple = jit::unpickle(
45 payload,
46 payload_size,
47 *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
48 message.tensors());
49 const auto& tupleElements = std::move(*std::move(tuple).toTuple()).elements();
50
51 // Build RRefBackwardReq.
52 TORCH_INTERNAL_ASSERT(tupleElements.size() == 3);
53
54 // Retrieve all fields.
55 bool retainGraph = tupleElements[2].toBool();
56 int64_t autogradContextId = tupleElements[1].toInt();
57 rpc::RRefId rrefId = rpc::RRefId::fromIValue(tupleElements[0]);
58
59 return std::make_unique<RRefBackwardReq>(
60 rrefId, autogradContextId, retainGraph);
61 }
62
getRRefId() const63 const rpc::RRefId& RRefBackwardReq::getRRefId() const {
64 return rrefId_;
65 }
66
getAutogradContextId() const67 int64_t RRefBackwardReq::getAutogradContextId() const {
68 return autogradContextId_;
69 }
70
retainGraph() const71 bool RRefBackwardReq::retainGraph() const {
72 return retainGraph_;
73 }
74
75 } // namespace autograd
76 } // namespace distributed
77 } // namespace torch
78