1 #include "cpp_c10d_extension.hpp"
2
3 #include <map>
4
5 namespace c10d {
6
~WorkTest()7 ProcessGroupTest::WorkTest::~WorkTest() {}
8
isCompleted()9 bool ProcessGroupTest::WorkTest::isCompleted() {
10 return true;
11 }
12
isSuccess() const13 bool ProcessGroupTest::WorkTest::isSuccess() const {
14 return true;
15 }
16
wait(std::chrono::milliseconds)17 bool ProcessGroupTest::WorkTest::wait(std::chrono::milliseconds /* unused */) {
18 return true;
19 }
20
ProcessGroupTest(int rank,int size)21 ProcessGroupTest::ProcessGroupTest(int rank, int size)
22 : ProcessGroup(rank, size) {}
23
~ProcessGroupTest()24 ProcessGroupTest::~ProcessGroupTest() {}
25
broadcast(std::vector<at::Tensor> & tensors,const BroadcastOptions & opts)26 c10::intrusive_ptr<Work> ProcessGroupTest::broadcast(
27 std::vector<at::Tensor>& tensors,
28 const BroadcastOptions& opts) {
29 return c10::make_intrusive<ProcessGroupTest::WorkTest>();
30 }
31
allreduce(std::vector<at::Tensor> & tensors,const AllreduceOptions & opts)32 c10::intrusive_ptr<Work> ProcessGroupTest::allreduce(
33 std::vector<at::Tensor>& tensors,
34 const AllreduceOptions& opts) {
35 return c10::make_intrusive<ProcessGroupTest::WorkTest>();
36 }
37
allreduce_coalesced(std::vector<at::Tensor> & tensors,const AllreduceCoalescedOptions & opts)38 c10::intrusive_ptr<Work> ProcessGroupTest::allreduce_coalesced(
39 std::vector<at::Tensor>& tensors,
40 const AllreduceCoalescedOptions& opts) {
41 throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced");
42 }
43
reduce(std::vector<at::Tensor> & tensors,const ReduceOptions & opts)44 c10::intrusive_ptr<Work> ProcessGroupTest::reduce(
45 std::vector<at::Tensor>& tensors,
46 const ReduceOptions& opts) {
47 throw std::runtime_error("ProcessGroupTest does not support reduce");
48 }
49
allgather(std::vector<std::vector<at::Tensor>> & outputTensors,std::vector<at::Tensor> & inputTensors,const AllgatherOptions & opts)50 c10::intrusive_ptr<Work> ProcessGroupTest::allgather(
51 std::vector<std::vector<at::Tensor>>& outputTensors,
52 std::vector<at::Tensor>& inputTensors,
53 const AllgatherOptions& opts) {
54 throw std::runtime_error("ProcessGroupTest does not support allgather");
55 }
56
_allgather_base(at::Tensor & outputBuffer,at::Tensor & inputBuffer,const AllgatherOptions & opts)57 c10::intrusive_ptr<Work> ProcessGroupTest::_allgather_base(
58 at::Tensor& outputBuffer,
59 at::Tensor& inputBuffer,
60 const AllgatherOptions& opts) {
61 throw std::runtime_error("ProcessGroupTest does not support _allgather_base");
62 }
63
barrier(const BarrierOptions & opts)64 c10::intrusive_ptr<Work> ProcessGroupTest::barrier(
65 const BarrierOptions& opts) {
66 return c10::make_intrusive<ProcessGroupTest::WorkTest>();
67 }
68
gather(std::vector<std::vector<at::Tensor>> & outputTensors,std::vector<at::Tensor> & inputTensors,const GatherOptions & opts)69 c10::intrusive_ptr<Work> ProcessGroupTest::gather(
70 std::vector<std::vector<at::Tensor>>& outputTensors,
71 std::vector<at::Tensor>& inputTensors,
72 const GatherOptions& opts) {
73 throw std::runtime_error("ProcessGroupTest does not support gather");
74 }
75
scatter(std::vector<at::Tensor> & outputTensors,std::vector<std::vector<at::Tensor>> & inputTensors,const ScatterOptions & opts)76 c10::intrusive_ptr<Work> ProcessGroupTest::scatter(
77 std::vector<at::Tensor>& outputTensors,
78 std::vector<std::vector<at::Tensor>>& inputTensors,
79 const ScatterOptions& opts) {
80 throw std::runtime_error("ProcessGroupTest does not support scatter");
81 }
82
reduce_scatter(std::vector<at::Tensor> & outputTensors,std::vector<std::vector<at::Tensor>> & inputTensors,const ReduceScatterOptions & opts)83 c10::intrusive_ptr<Work> ProcessGroupTest::reduce_scatter(
84 std::vector<at::Tensor>& outputTensors,
85 std::vector<std::vector<at::Tensor>>& inputTensors,
86 const ReduceScatterOptions& opts) {
87 throw std::runtime_error("ProcessGroupTest does not support reduce_scatter");
88 }
89
send(std::vector<at::Tensor> & tensors,int dstRank,int tag)90 c10::intrusive_ptr<Work> ProcessGroupTest::send(
91 std::vector<at::Tensor>& tensors,
92 int dstRank,
93 int tag) {
94 throw std::runtime_error("ProcessGroupTest does not support send");
95 }
96
recv(std::vector<at::Tensor> & tensors,int srcRank,int tag)97 c10::intrusive_ptr<Work> ProcessGroupTest::recv(
98 std::vector<at::Tensor>& tensors,
99 int srcRank,
100 int tag) {
101 throw std::runtime_error("ProcessGroupTest does not support recv");
102 }
103
recvAnysource(std::vector<at::Tensor> & tensor,int tag)104 c10::intrusive_ptr<Work> ProcessGroupTest::recvAnysource(
105 std::vector<at::Tensor>& tensor,
106 int tag) {
107 throw std::runtime_error("ProcessGroupTest does not support recvAnysource");
108 }
109
createProcessGroupTest(const c10::intrusive_ptr<::c10d::Store> & store,int rank,int size,const std::chrono::duration<float> & timeout)110 c10::intrusive_ptr<ProcessGroup> ProcessGroupTest::createProcessGroupTest(
111 const c10::intrusive_ptr<::c10d::Store>& store,
112 int rank,
113 int size,
114 const std::chrono::duration<float>& timeout) {
115 return c10::make_intrusive<ProcessGroupTest>(rank, size);
116 }
117
PYBIND11_MODULE(TORCH_EXTENSION_NAME,m)118 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
119 m.def("createProcessGroupTest", &ProcessGroupTest::createProcessGroupTest);
120 }
121
122 } // namespace c10d
123