// aten/src/ATen/native/TensorIteratorReduce.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorIterator.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIteratorInternal.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

#include <c10/util/irange.h>

/// Contains the implementation of parallel reductions in TensorIterator.

namespace at {

using loop2d_t = TensorIteratorBase::loop2d_t;

static bool use_two_pass_reduction(TensorIteratorBase& iter);
static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop);
static void parallel_dim_reduction(TensorIteratorBase& iter, loop2d_t loop);

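// Entry point for parallel reductions over a (single input, single output)
// TensorIterator. Falls back to a serial loop for small inputs, single-thread
// runs, or nested parallel regions; uses a two-pass reduction when the output
// is a single element; otherwise parallelizes over one dimension.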
void TensorIteratorBase::parallel_reduce(loop2d_t loop) {
  TORCH_CHECK(ntensors() == 2, "parallel_reduce only supports one input and one output");
  int64_t numel = this->numel();
  if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
      at::in_parallel_region()) {
    serial_for_each(loop, {0, numel});
  } else if (use_two_pass_reduction(*this)) {
    two_pass_reduction(*this, loop);
  } else {
    parallel_dim_reduction(*this, loop);
  }
}

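// The two-pass strategy is only used when the whole tensor reduces down to a
// single output element.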
static bool use_two_pass_reduction(TensorIteratorBase& iter) {
  return iter.output(0).numel() == 1;
}

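// First pass: each thread reduces its share of the input into a private row
// of a temporary buffer seeded with the current output value. A second pass
// then reduces the per-thread partial results into the destination.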
static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop) {
  const int max_threads = at::get_num_threads();

  const auto& dst = iter.output(0);
  auto unsqueezed = dst.unsqueeze(0);
  auto buffer_shape = DimVector(unsqueezed.sizes());
  buffer_shape[0] = max_threads;
  auto buffer = at::empty(buffer_shape, dst.options());
  // Fill with the identity
  buffer.copy_(unsqueezed);

  auto buffer_stride = buffer.strides()[0] * buffer.element_size();
  auto buffer_0 = buffer[0];
  auto first_reduce = TensorIterator::reduce_op(buffer_0, iter.input(0));
  TORCH_INTERNAL_ASSERT(first_reduce.output(0).is_alias_of(buffer_0));

  at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
    const auto thread_num = at::get_thread_num();
    auto shape = first_reduce.shape();
    auto strides = first_reduce.get_strides();

    // Bump output ptr so each thread has its own output slice
    auto base_ptrs = first_reduce.get_base_ptrs();
    base_ptrs[0] += buffer_stride * thread_num;

    at::internal::serial_for_each(shape, strides, base_ptrs.data(),
                                  base_ptrs.size(), loop, {begin, end});
  });

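  // Second pass: reduce the per-thread partial results into the original
  // output.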
  auto final_reduce = TensorIterator::reduce_op(unsqueezed, buffer);
  final_reduce.for_each(loop);
}

/// Chooses a dimension over which to parallelize. Prefers the outer-most
/// dimension that's larger than the number of available threads.
static int find_split_dim(TensorIteratorBase& iter) {
  int num_threads = at::get_num_threads();
  auto shape = iter.shape();

  // start with the outer-most dimension
  int best_dim = iter.ndim() - 1;
  for (int dim = best_dim; dim >= 0 && !iter.is_dim_reduced(dim); dim--) {
    if (shape[dim] >= num_threads) {
      return dim;
    } else if (shape[dim] > shape[best_dim]) {
      best_dim = dim;
    }
  }

  AT_ASSERT(!iter.is_dim_reduced(best_dim));
  return best_dim;
}

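// Rounds [begin, end) down to a multiple of `multiple` columns so each thread
// works on aligned chunks; the final column boundary is left untouched so no
// columns are dropped.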
static std::tuple<int64_t, int64_t>
round_columns(TensorIteratorBase& iter, int dim, int multiple, int64_t begin, int64_t end) {
  begin = begin - (begin % multiple);
  if (end != iter.shape()[dim]) {
    // only round the 'end' column down if it's not the final column
    end = end - (end % multiple);
  }
  return std::make_tuple(begin, end);
}

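// Splits the dimension chosen by find_split_dim across threads; each chunk is
// narrowed into an independent sub-iterator that runs the reduction loop over
// its slice of columns.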
static void parallel_dim_reduction(TensorIteratorBase& iter, loop2d_t loop) {
  AT_ASSERT(iter.ndim() >= 1);
  int dim = find_split_dim(iter);
  int64_t cols = iter.shape()[dim];
  int element_size = iter.element_size(/*arg=*/1);

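  // Adjacent columns are contiguous in memory when the input's stride along
  // `dim` equals its element size; only then is it worth aligning chunk
  // boundaries to 128 bytes below.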
  bool should_round_columns = iter.strides(1)[dim] == element_size;
  at::parallel_for(0, cols, 1, [&](int64_t begin, int64_t end) {
    if (should_round_columns) {
      // round columns to multiples of 128 bytes if adjacent columns are
      // contiguous in memory.
      int64_t cols_per_128_bytes = 128 / element_size;
      std::tie(begin, end) = round_columns(iter, dim, cols_per_128_bytes, begin, end);
    }
    if (begin == end) {
      return;
    }
    auto sub_iter = TensorIterator(iter);
    sub_iter.narrow(dim, begin, end - begin);
    sub_iter.for_each(loop);
  });
}

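// Invokes `loop` once for every slice of the input that reduces to a single
// output element. When profitable (and `parallelize` is true), the slices are
// distributed across threads along a non-reduced dimension.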
void TensorIteratorBase::foreach_reduced_elt(loop_subiter_t loop, bool parallelize) {
  AT_ASSERT(ninputs() == 1);
  AT_ASSERT(noutputs() >= 1);

  auto shape = this->shape();
  if (output(0).numel() == 0) {
    return;
  }
  if (output(0).numel() == 1) {
    loop(*this);
  }
  else if (numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
      at::in_parallel_region() || !parallelize) {
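    // Serial path: step through the non-reduced dims one index at a time,
    // handing `loop` a sub-iterator that spans only the reduced dims.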
    auto reduce_dims = num_reduce_dims();

    auto non_reduced_shape = shape.slice(reduce_dims, shape.size() - reduce_dims);

    int64_t non_reduced_numel = 1;
    for (const auto i : non_reduced_shape) {
      non_reduced_numel *= i;
    }
    DimCounter dims {non_reduced_shape, {0, non_reduced_numel}};
    while (!dims.is_done()) {
      TensorIterator reduced = *this;
      reduced.select_all_keeping_dim(reduce_dims, dims.values);
      loop(reduced);
      dims.increment({1, 1});
    }
  }
  else {
    int dim = find_split_dim(*this);
    int64_t cols = shape[dim];
    at::parallel_for(0, cols, 1, [&](int64_t begin, int64_t end) {
      if (begin == end) {
        return;
      }

      TensorIterator sub_iter(*this);

      sub_iter.narrow(dim, begin, end - begin);
      // On some broken setups, `#ifdef _OPENMP` is true,
      // and `get_num_threads` returns > 1, but
      // `#pragma omp parallel` is ignored.
      // There is no API to check for this, so we need to explicitly
      // stop trying to parallelize if we've already gotten here.
      //
      // (If we are on one of those broken setups, we will
      //  only have one thread here, and end - begin == cols.)
      sub_iter.foreach_reduced_elt(loop, false);
    });
  }
}

}  // namespace at