xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/cpp_c10d_extension.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include "cpp_c10d_extension.hpp"

#include <map>
#include <stdexcept>
4 
5 namespace c10d {
6 
~WorkTest()7 ProcessGroupTest::WorkTest::~WorkTest() {}
8 
isCompleted()9 bool ProcessGroupTest::WorkTest::isCompleted() {
10   return true;
11 }
12 
isSuccess() const13 bool ProcessGroupTest::WorkTest::isSuccess() const {
14   return true;
15 }
16 
wait(std::chrono::milliseconds)17 bool ProcessGroupTest::WorkTest::wait(std::chrono::milliseconds /* unused */) {
18   return true;
19 }
20 
ProcessGroupTest(int rank,int size)21 ProcessGroupTest::ProcessGroupTest(int rank, int size)
22     : ProcessGroup(rank, size) {}
23 
~ProcessGroupTest()24 ProcessGroupTest::~ProcessGroupTest() {}
25 
broadcast(std::vector<at::Tensor> & tensors,const BroadcastOptions & opts)26 c10::intrusive_ptr<Work> ProcessGroupTest::broadcast(
27     std::vector<at::Tensor>& tensors,
28     const BroadcastOptions& opts) {
29   return c10::make_intrusive<ProcessGroupTest::WorkTest>();
30 }
31 
allreduce(std::vector<at::Tensor> & tensors,const AllreduceOptions & opts)32 c10::intrusive_ptr<Work> ProcessGroupTest::allreduce(
33     std::vector<at::Tensor>& tensors,
34     const AllreduceOptions& opts) {
35   return c10::make_intrusive<ProcessGroupTest::WorkTest>();
36 }
37 
allreduce_coalesced(std::vector<at::Tensor> & tensors,const AllreduceCoalescedOptions & opts)38 c10::intrusive_ptr<Work> ProcessGroupTest::allreduce_coalesced(
39       std::vector<at::Tensor>& tensors,
40       const AllreduceCoalescedOptions& opts) {
41   throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced");
42 }
43 
reduce(std::vector<at::Tensor> & tensors,const ReduceOptions & opts)44 c10::intrusive_ptr<Work> ProcessGroupTest::reduce(
45     std::vector<at::Tensor>& tensors,
46     const ReduceOptions& opts) {
47   throw std::runtime_error("ProcessGroupTest does not support reduce");
48 }
49 
allgather(std::vector<std::vector<at::Tensor>> & outputTensors,std::vector<at::Tensor> & inputTensors,const AllgatherOptions & opts)50 c10::intrusive_ptr<Work> ProcessGroupTest::allgather(
51     std::vector<std::vector<at::Tensor>>& outputTensors,
52     std::vector<at::Tensor>& inputTensors,
53     const AllgatherOptions& opts) {
54   throw std::runtime_error("ProcessGroupTest does not support allgather");
55 }
56 
_allgather_base(at::Tensor & outputBuffer,at::Tensor & inputBuffer,const AllgatherOptions & opts)57 c10::intrusive_ptr<Work> ProcessGroupTest::_allgather_base(
58     at::Tensor& outputBuffer,
59     at::Tensor& inputBuffer,
60     const AllgatherOptions& opts) {
61   throw std::runtime_error("ProcessGroupTest does not support _allgather_base");
62 }
63 
barrier(const BarrierOptions & opts)64 c10::intrusive_ptr<Work> ProcessGroupTest::barrier(
65     const BarrierOptions& opts) {
66   return c10::make_intrusive<ProcessGroupTest::WorkTest>();
67 }
68 
gather(std::vector<std::vector<at::Tensor>> & outputTensors,std::vector<at::Tensor> & inputTensors,const GatherOptions & opts)69 c10::intrusive_ptr<Work> ProcessGroupTest::gather(
70     std::vector<std::vector<at::Tensor>>& outputTensors,
71     std::vector<at::Tensor>& inputTensors,
72     const GatherOptions& opts) {
73   throw std::runtime_error("ProcessGroupTest does not support gather");
74 }
75 
scatter(std::vector<at::Tensor> & outputTensors,std::vector<std::vector<at::Tensor>> & inputTensors,const ScatterOptions & opts)76 c10::intrusive_ptr<Work> ProcessGroupTest::scatter(
77     std::vector<at::Tensor>& outputTensors,
78     std::vector<std::vector<at::Tensor>>& inputTensors,
79     const ScatterOptions& opts) {
80   throw std::runtime_error("ProcessGroupTest does not support scatter");
81 }
82 
reduce_scatter(std::vector<at::Tensor> & outputTensors,std::vector<std::vector<at::Tensor>> & inputTensors,const ReduceScatterOptions & opts)83 c10::intrusive_ptr<Work> ProcessGroupTest::reduce_scatter(
84     std::vector<at::Tensor>& outputTensors,
85     std::vector<std::vector<at::Tensor>>& inputTensors,
86     const ReduceScatterOptions& opts) {
87   throw std::runtime_error("ProcessGroupTest does not support reduce_scatter");
88 }
89 
send(std::vector<at::Tensor> & tensors,int dstRank,int tag)90 c10::intrusive_ptr<Work> ProcessGroupTest::send(
91     std::vector<at::Tensor>& tensors,
92     int dstRank,
93     int tag) {
94   throw std::runtime_error("ProcessGroupTest does not support send");
95 }
96 
recv(std::vector<at::Tensor> & tensors,int srcRank,int tag)97 c10::intrusive_ptr<Work> ProcessGroupTest::recv(
98     std::vector<at::Tensor>& tensors,
99     int srcRank,
100     int tag) {
101   throw std::runtime_error("ProcessGroupTest does not support recv");
102 }
103 
recvAnysource(std::vector<at::Tensor> & tensor,int tag)104 c10::intrusive_ptr<Work> ProcessGroupTest::recvAnysource(
105     std::vector<at::Tensor>& tensor,
106     int tag) {
107   throw std::runtime_error("ProcessGroupTest does not support recvAnysource");
108 }
109 
createProcessGroupTest(const c10::intrusive_ptr<::c10d::Store> & store,int rank,int size,const std::chrono::duration<float> & timeout)110 c10::intrusive_ptr<ProcessGroup> ProcessGroupTest::createProcessGroupTest(
111     const c10::intrusive_ptr<::c10d::Store>& store,
112     int rank,
113     int size,
114     const std::chrono::duration<float>& timeout) {
115   return c10::make_intrusive<ProcessGroupTest>(rank, size);
116 }
117 
PYBIND11_MODULE(TORCH_EXTENSION_NAME,m)118 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
119   m.def("createProcessGroupTest", &ProcessGroupTest::createProcessGroupTest);
120 }
121 
122 } // namespace c10d
123