#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorIterator.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIteratorInternal.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

#include <c10/util/irange.h>

/// Contains the implementation of parallel reductions in TensorIterator.

namespace at {

using loop2d_t = TensorIteratorBase::loop2d_t;

static bool use_two_pass_reduction(TensorIteratorBase& iter);
static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop);
static void parallel_dim_reduction(TensorIteratorBase& iter, loop2d_t loop);

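// Dispatches a reduction loop to one of three strategies: run serially when
// the problem is small, only one thread is available, or we are already in a
// parallel region; use a two-pass reduction when the output is a single
// element; otherwise parallelize over a non-reduced dimension.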
void TensorIteratorBase::parallel_reduce(loop2d_t loop) {
  TORCH_CHECK(ntensors() == 2, "parallel_reduce only supports one input and one output");
  int64_t numel = this->numel();
  if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
      at::in_parallel_region()) {
    serial_for_each(loop, {0, numel});
  } else if (use_two_pass_reduction(*this)) {
    two_pass_reduction(*this, loop);
  } else {
    parallel_dim_reduction(*this, loop);
  }
}

static bool use_two_pass_reduction(TensorIteratorBase& iter) {
  return iter.output(0).numel() == 1;
}

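// Two-pass strategy for reductions to a single output element: pass one gives
// each thread a private slice of a [num_threads, ...] buffer to accumulate
// into, pass two reduces that buffer into the original output.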
static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop) {
  const int max_threads = at::get_num_threads();

  const auto& dst = iter.output(0);
  auto unsqueezed = dst.unsqueeze(0);
  auto buffer_shape = DimVector(unsqueezed.sizes());
  buffer_shape[0] = max_threads;
  auto buffer = at::empty(buffer_shape, dst.options());
  // Fill each per-thread slice with the identity value by copying the
  // (identity-initialized) output into every row of the buffer.
  buffer.copy_(unsqueezed);

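  // Stride (in bytes) between per-thread slices of the buffer; the base
  // pointers used below are byte pointers, so the offset must be in bytes.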
  auto buffer_stride = buffer.strides()[0] * buffer.element_size();
  auto buffer_0 = buffer[0];
  auto first_reduce = TensorIterator::reduce_op(buffer_0, iter.input(0));
  TORCH_INTERNAL_ASSERT(first_reduce.output(0).is_alias_of(buffer_0));

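  // First pass: each thread serially reduces its share of the input elements
  // into its own slice of the buffer.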
  at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
    const auto thread_num = at::get_thread_num();
    auto shape = first_reduce.shape();
    auto strides = first_reduce.get_strides();

    // Bump output ptr so each thread has its own output slice
    auto base_ptrs = first_reduce.get_base_ptrs();
    base_ptrs[0] += buffer_stride * thread_num;

    at::internal::serial_for_each(shape, strides, base_ptrs.data(),
                                  base_ptrs.size(), loop, {begin, end});
  });

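  // Second pass: combine the per-thread partial results into the original
  // output.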
  auto final_reduce = TensorIterator::reduce_op(unsqueezed, buffer);
  final_reduce.for_each(loop);
}

/// Chooses a dimension over which to parallelize. Prefers the outer-most
/// dimension that's larger than the number of available threads.
static int find_split_dim(TensorIteratorBase& iter) {
  int num_threads = at::get_num_threads();
  auto shape = iter.shape();

  // start with the outer-most dimension
  int best_dim = iter.ndim() - 1;
  for (int dim = best_dim; dim >= 0 && !iter.is_dim_reduced(dim); dim--) {
    if (shape[dim] >= num_threads) {
      return dim;
    } else if (shape[dim] > shape[best_dim]) {
      best_dim = dim;
    }
  }

  AT_ASSERT(!iter.is_dim_reduced(best_dim));
  return best_dim;
}

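// Rounds a [begin, end) column range down to a multiple of `multiple` columns
// so that chunk boundaries stay aligned; the final chunk keeps its original
// end so the tail columns are not dropped.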
static std::tuple<int64_t, int64_t>
round_columns(TensorIteratorBase& iter, int dim, int multiple, int64_t begin, int64_t end) {
  begin = begin - (begin % multiple);
  if (end != iter.shape()[dim]) {
    // only round the 'end' column down if it's not the final column
    end = end - (end % multiple);
  }
  return std::make_tuple(begin, end);
}

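// Parallelizes over one non-reduced dimension: each thread narrows a copy of
// the iterator to its range of columns along that dimension and runs the
// reduction loop on the narrowed sub-iterator.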
static void parallel_dim_reduction(TensorIteratorBase& iter, loop2d_t loop) {
  AT_ASSERT(iter.ndim() >= 1);
  int dim = find_split_dim(iter);
  int64_t cols = iter.shape()[dim];
  int element_size = iter.element_size(/*arg=*/1);

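  // Adjacent columns of the input (arg 1) are contiguous in memory when its
  // stride along the split dimension is exactly one element.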
  bool should_round_columns = iter.strides(1)[dim] == element_size;
  at::parallel_for(0, cols, 1, [&](int64_t begin, int64_t end) {
    if (should_round_columns) {
      // round columns to multiples of 128 bytes if adjacent columns are
      // contiguous in memory.
      int64_t cols_per_128_bytes = 128 / element_size;
      std::tie(begin, end) = round_columns(iter, dim, cols_per_128_bytes, begin, end);
    }
    if (begin == end) {
      return;
    }
    auto sub_iter = TensorIterator(iter);
    sub_iter.narrow(dim, begin, end - begin);
    sub_iter.for_each(loop);
  });
}

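// Invokes `loop` once for every element of the reduced output. Runs serially
// when the problem is small, only one thread is available, we are already in
// a parallel region, or `parallelize` is false; otherwise splits a
// non-reduced dimension across threads and recurses with parallelize=false.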
void TensorIteratorBase::foreach_reduced_elt(loop_subiter_t loop, bool parallelize) {
  AT_ASSERT(ninputs() == 1);
  AT_ASSERT(noutputs() >= 1);

  auto shape = this->shape();
  if (output(0).numel() == 0) {
    return;
  }
  if (output(0).numel() == 1) {
    loop(*this);
  }
  else if (numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
      at::in_parallel_region() || !parallelize) {
    auto reduce_dims = num_reduce_dims();

    auto non_reduced_shape = shape.slice(reduce_dims, shape.size() - reduce_dims);

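    // Visit each output element by stepping a DimCounter over the non-reduced
    // dimensions; select_all_keeping_dim() narrows the iterator to that single
    // output element (plus all of its reduction inputs) before calling `loop`.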
    int64_t non_reduced_numel = 1;
    for (const auto i : non_reduced_shape) {
      non_reduced_numel *= i;
    }
    DimCounter dims {non_reduced_shape, {0, non_reduced_numel}};
    while (!dims.is_done()) {
      TensorIterator reduced = *this;
      reduced.select_all_keeping_dim(reduce_dims, dims.values);
      loop(reduced);
      dims.increment({1, 1});
    }
  }
  else {
    int dim = find_split_dim(*this);
    int64_t cols = shape[dim];
    at::parallel_for(0, cols, 1, [&](int64_t begin, int64_t end) {
      if (begin == end) {
        return;
      }

      TensorIterator sub_iter(*this);

      sub_iter.narrow(dim, begin, end - begin);
      // On some broken setups, `#ifdef _OPENMP` is true,
      // and `get_num_threads` returns > 1, but
      // `#pragma omp parallel` is ignored.
      // There is no API to check for this, so we need to explicitly
      // stop trying to parallelize if we've already gotten here.
      //
      // (If we are on one of those broken setups, we will
      // only have one thread here, and end - begin == cols.)
      sub_iter.foreach_reduced_elt(loop, false);
    });
  }
}

}  // namespace at