xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/pooling.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/modules/pooling.h>
2 
3 #include <torch/expanding_array.h>
4 
5 namespace F = torch::nn::functional;
6 
7 namespace torch {
8 namespace nn {
9 
10 template <size_t D, typename Derived>
AvgPoolImpl(const AvgPoolOptions<D> & options_)11 AvgPoolImpl<D, Derived>::AvgPoolImpl(const AvgPoolOptions<D>& options_)
12     : options(options_) {}
13 
14 template <size_t D, typename Derived>
reset()15 void AvgPoolImpl<D, Derived>::reset() {}
16 
17 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const18 void AvgPoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
19   stream << "torch::nn::AvgPool" << D << "d"
20          << "(kernel_size=" << options.kernel_size()
21          << ", stride=" << options.stride() << ", padding=" << options.padding()
22          << ")";
23 }
24 
forward(const Tensor & input)25 Tensor AvgPool1dImpl::forward(const Tensor& input) {
26   return F::detail::avg_pool1d(
27       input,
28       options.kernel_size(),
29       options.stride(),
30       options.padding(),
31       options.ceil_mode(),
32       options.count_include_pad());
33 }
34 
forward(const Tensor & input)35 Tensor AvgPool2dImpl::forward(const Tensor& input) {
36   return F::detail::avg_pool2d(
37       input,
38       options.kernel_size(),
39       options.stride(),
40       options.padding(),
41       options.ceil_mode(),
42       options.count_include_pad(),
43       options.divisor_override());
44 }
45 
forward(const Tensor & input)46 Tensor AvgPool3dImpl::forward(const Tensor& input) {
47   return F::detail::avg_pool3d(
48       input,
49       options.kernel_size(),
50       options.stride(),
51       options.padding(),
52       options.ceil_mode(),
53       options.count_include_pad(),
54       options.divisor_override());
55 }
56 
57 template class AvgPoolImpl<1, AvgPool1dImpl>;
58 template class AvgPoolImpl<2, AvgPool2dImpl>;
59 template class AvgPoolImpl<3, AvgPool3dImpl>;
60 
61 // ============================================================================
62 
63 template <size_t D, typename Derived>
MaxPoolImpl(const MaxPoolOptions<D> & options_)64 MaxPoolImpl<D, Derived>::MaxPoolImpl(const MaxPoolOptions<D>& options_)
65     : options(options_) {}
66 
67 template <size_t D, typename Derived>
reset()68 void MaxPoolImpl<D, Derived>::reset() {}
69 
70 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const71 void MaxPoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
72   stream << std::boolalpha << "torch::nn::MaxPool" << D << "d"
73          << "(kernel_size=" << options.kernel_size()
74          << ", stride=" << options.stride() << ", padding=" << options.padding()
75          << ", dilation=" << options.dilation()
76          << ", ceil_mode=" << options.ceil_mode() << ")";
77 }
78 
forward(const Tensor & input)79 Tensor MaxPool1dImpl::forward(const Tensor& input) {
80   return F::detail::max_pool1d(
81       input,
82       options.kernel_size(),
83       options.stride(),
84       options.padding(),
85       options.dilation(),
86       options.ceil_mode());
87 }
88 
forward_with_indices(const Tensor & input)89 std::tuple<Tensor, Tensor> MaxPool1dImpl::forward_with_indices(
90     const Tensor& input) {
91   return F::detail::max_pool1d_with_indices(
92       input,
93       options.kernel_size(),
94       options.stride(),
95       options.padding(),
96       options.dilation(),
97       options.ceil_mode());
98 }
99 
forward(const Tensor & input)100 Tensor MaxPool2dImpl::forward(const Tensor& input) {
101   return F::detail::max_pool2d(
102       input,
103       options.kernel_size(),
104       options.stride(),
105       options.padding(),
106       options.dilation(),
107       options.ceil_mode());
108 }
109 
forward_with_indices(const Tensor & input)110 std::tuple<Tensor, Tensor> MaxPool2dImpl::forward_with_indices(
111     const Tensor& input) {
112   return F::detail::max_pool2d_with_indices(
113       input,
114       options.kernel_size(),
115       options.stride(),
116       options.padding(),
117       options.dilation(),
118       options.ceil_mode());
119 }
120 
forward(const Tensor & input)121 Tensor MaxPool3dImpl::forward(const Tensor& input) {
122   return F::detail::max_pool3d(
123       input,
124       options.kernel_size(),
125       options.stride(),
126       options.padding(),
127       options.dilation(),
128       options.ceil_mode());
129 }
130 
forward_with_indices(const Tensor & input)131 std::tuple<Tensor, Tensor> MaxPool3dImpl::forward_with_indices(
132     const Tensor& input) {
133   return F::detail::max_pool3d_with_indices(
134       input,
135       options.kernel_size(),
136       options.stride(),
137       options.padding(),
138       options.dilation(),
139       options.ceil_mode());
140 }
141 
142 template class MaxPoolImpl<1, MaxPool1dImpl>;
143 template class MaxPoolImpl<2, MaxPool2dImpl>;
144 template class MaxPoolImpl<3, MaxPool3dImpl>;
145 
146 // ============================================================================
147 
forward(const Tensor & input)148 Tensor AdaptiveMaxPool1dImpl::forward(const Tensor& input) {
149   return F::detail::adaptive_max_pool1d(input, options.output_size());
150 }
151 
forward_with_indices(const Tensor & input)152 std::tuple<Tensor, Tensor> AdaptiveMaxPool1dImpl::forward_with_indices(
153     const Tensor& input) {
154   return F::detail::adaptive_max_pool1d_with_indices(
155       input, options.output_size());
156 }
157 
forward(const Tensor & input)158 Tensor AdaptiveMaxPool2dImpl::forward(const Tensor& input) {
159   return F::detail::adaptive_max_pool2d(input, options.output_size());
160 }
161 
forward_with_indices(const Tensor & input)162 std::tuple<Tensor, Tensor> AdaptiveMaxPool2dImpl::forward_with_indices(
163     const Tensor& input) {
164   return F::detail::adaptive_max_pool2d_with_indices(
165       input, options.output_size());
166 }
167 
forward(const Tensor & input)168 Tensor AdaptiveMaxPool3dImpl::forward(const Tensor& input) {
169   return F::detail::adaptive_max_pool3d(input, options.output_size());
170 }
171 
forward_with_indices(const Tensor & input)172 std::tuple<Tensor, Tensor> AdaptiveMaxPool3dImpl::forward_with_indices(
173     const Tensor& input) {
174   return F::detail::adaptive_max_pool3d_with_indices(
175       input, options.output_size());
176 }
177 
178 template class AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl>;
179 template class AdaptiveMaxPoolImpl<
180     2,
181     ExpandingArrayWithOptionalElem<2>,
182     AdaptiveMaxPool2dImpl>;
183 template class AdaptiveMaxPoolImpl<
184     3,
185     ExpandingArrayWithOptionalElem<3>,
186     AdaptiveMaxPool3dImpl>;
187 
188 // ============================================================================
189 
forward(const Tensor & input)190 Tensor AdaptiveAvgPool1dImpl::forward(const Tensor& input) {
191   return F::detail::adaptive_avg_pool1d(input, options.output_size());
192 }
193 
forward(const Tensor & input)194 Tensor AdaptiveAvgPool2dImpl::forward(const Tensor& input) {
195   return F::detail::adaptive_avg_pool2d(input, options.output_size());
196 }
197 
forward(const Tensor & input)198 Tensor AdaptiveAvgPool3dImpl::forward(const Tensor& input) {
199   return F::detail::adaptive_avg_pool3d(input, options.output_size());
200 }
201 
202 template class AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl>;
203 template class AdaptiveAvgPoolImpl<
204     2,
205     ExpandingArrayWithOptionalElem<2>,
206     AdaptiveAvgPool2dImpl>;
207 template class AdaptiveAvgPoolImpl<
208     3,
209     ExpandingArrayWithOptionalElem<3>,
210     AdaptiveAvgPool3dImpl>;
211 
212 // ============================================================================
213 
214 template <size_t D, typename Derived>
MaxUnpoolImpl(const MaxUnpoolOptions<D> & options_)215 MaxUnpoolImpl<D, Derived>::MaxUnpoolImpl(const MaxUnpoolOptions<D>& options_)
216     : options(options_) {}
217 
218 template <size_t D, typename Derived>
reset()219 void MaxUnpoolImpl<D, Derived>::reset() {}
220 
221 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const222 void MaxUnpoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
223   stream << std::boolalpha << "torch::nn::MaxUnpool" << D << "d"
224          << "(kernel_size=" << options.kernel_size()
225          << ", stride=" << options.stride() << ", padding=" << options.padding()
226          << ")";
227 }
228 
forward(const Tensor & input,const Tensor & indices,const std::optional<std::vector<int64_t>> & output_size)229 Tensor MaxUnpool1dImpl::forward(
230     const Tensor& input,
231     const Tensor& indices,
232     const std::optional<std::vector<int64_t>>& output_size) {
233   return F::detail::max_unpool1d(
234       input,
235       indices,
236       options.kernel_size(),
237       options.stride(),
238       options.padding(),
239       output_size);
240 }
241 
forward(const Tensor & input,const Tensor & indices,const std::optional<std::vector<int64_t>> & output_size)242 Tensor MaxUnpool2dImpl::forward(
243     const Tensor& input,
244     const Tensor& indices,
245     const std::optional<std::vector<int64_t>>& output_size) {
246   return F::detail::max_unpool2d(
247       input,
248       indices,
249       options.kernel_size(),
250       options.stride(),
251       options.padding(),
252       output_size);
253 }
254 
forward(const Tensor & input,const Tensor & indices,const std::optional<std::vector<int64_t>> & output_size)255 Tensor MaxUnpool3dImpl::forward(
256     const Tensor& input,
257     const Tensor& indices,
258     const std::optional<std::vector<int64_t>>& output_size) {
259   return F::detail::max_unpool3d(
260       input,
261       indices,
262       options.kernel_size(),
263       options.stride(),
264       options.padding(),
265       output_size);
266 }
267 
268 template class MaxUnpoolImpl<1, MaxUnpool1dImpl>;
269 template class MaxUnpoolImpl<2, MaxUnpool2dImpl>;
270 template class MaxUnpoolImpl<3, MaxUnpool3dImpl>;
271 
272 // ============================================================================
273 
FractionalMaxPool2dImpl(FractionalMaxPool2dOptions options_)274 FractionalMaxPool2dImpl::FractionalMaxPool2dImpl(
275     FractionalMaxPool2dOptions options_)
276     : options(std::move(options_)) {
277   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
278   reset();
279 }
280 
reset()281 void FractionalMaxPool2dImpl::reset() {
282   _random_samples =
283       register_buffer("_random_samples", options._random_samples());
284   if (options.output_size() == std::nullopt &&
285       options.output_ratio() == std::nullopt) {
286     TORCH_CHECK(
287         false,
288         "FractionalMaxPool2d requires specifying either ",
289         "an output size, or a pooling ratio");
290   }
291   if (options.output_size() != std::nullopt &&
292       options.output_ratio() != std::nullopt) {
293     TORCH_CHECK(
294         false, "only one of output_size and output_ratio may be specified");
295   }
296   if (options.output_ratio() != std::nullopt) {
297     at::ArrayRef<double> output_ratio =
298         at::ArrayRef<double>(options.output_ratio().value());
299     if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] &&
300           output_ratio[1] < 1)) {
301       TORCH_CHECK(
302           false,
303           "output_ratio must be between 0 and 1 (got ",
304           output_ratio,
305           ")");
306     }
307   }
308 }
309 
forward(const Tensor & input)310 Tensor FractionalMaxPool2dImpl::forward(const Tensor& input) {
311   return F::detail::fractional_max_pool2d(
312       input,
313       options.kernel_size(),
314       options.output_size(),
315       options.output_ratio(),
316       _random_samples);
317 }
318 
forward_with_indices(const Tensor & input)319 std::tuple<Tensor, Tensor> FractionalMaxPool2dImpl::forward_with_indices(
320     const Tensor& input) {
321   return F::detail::fractional_max_pool2d_with_indices(
322       input,
323       options.kernel_size(),
324       options.output_size(),
325       options.output_ratio(),
326       _random_samples);
327 }
328 
pretty_print(std::ostream & stream) const329 void FractionalMaxPool2dImpl::pretty_print(std::ostream& stream) const {
330   stream << "torch::nn::FractionalMaxPool2d()";
331 }
332 
FractionalMaxPool3dImpl(FractionalMaxPool3dOptions options_)333 FractionalMaxPool3dImpl::FractionalMaxPool3dImpl(
334     FractionalMaxPool3dOptions options_)
335     : options(std::move(options_)) {
336   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
337   reset();
338 }
339 
reset()340 void FractionalMaxPool3dImpl::reset() {
341   _random_samples =
342       register_buffer("_random_samples", options._random_samples());
343   if (options.output_size() == std::nullopt &&
344       options.output_ratio() == std::nullopt) {
345     TORCH_CHECK(
346         false,
347         "FractionalMaxPool3d requires specifying either ",
348         "an output size, or a pooling ratio");
349   }
350   if (options.output_size() != std::nullopt &&
351       options.output_ratio() != std::nullopt) {
352     TORCH_CHECK(
353         false, "only one of output_size and output_ratio may be specified");
354   }
355   if (options.output_ratio() != std::nullopt) {
356     at::ArrayRef<double> output_ratio =
357         at::ArrayRef<double>(options.output_ratio().value());
358     if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] &&
359           output_ratio[1] < 1 && 0 < output_ratio[2] && output_ratio[2] < 1)) {
360       TORCH_CHECK(
361           false,
362           "output_ratio must be between 0 and 1 (got ",
363           output_ratio,
364           ")");
365     }
366   }
367 }
368 
forward(const Tensor & input)369 Tensor FractionalMaxPool3dImpl::forward(const Tensor& input) {
370   return F::detail::fractional_max_pool3d(
371       input,
372       options.kernel_size(),
373       options.output_size(),
374       options.output_ratio(),
375       _random_samples);
376 }
377 
forward_with_indices(const Tensor & input)378 std::tuple<Tensor, Tensor> FractionalMaxPool3dImpl::forward_with_indices(
379     const Tensor& input) {
380   return F::detail::fractional_max_pool3d_with_indices(
381       input,
382       options.kernel_size(),
383       options.output_size(),
384       options.output_ratio(),
385       _random_samples);
386 }
387 
pretty_print(std::ostream & stream) const388 void FractionalMaxPool3dImpl::pretty_print(std::ostream& stream) const {
389   stream << "torch::nn::FractionalMaxPool3d()";
390 }
391 
392 // ============================================================================
393 
394 template <size_t D, typename Derived>
LPPoolImpl(const LPPoolOptions<D> & options_)395 LPPoolImpl<D, Derived>::LPPoolImpl(const LPPoolOptions<D>& options_)
396     : options(options_) {}
397 
398 template <size_t D, typename Derived>
reset()399 void LPPoolImpl<D, Derived>::reset() {}
400 
401 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const402 void LPPoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
403   stream << std::boolalpha << "torch::nn::LPPool" << D << "d("
404          << "norm_type=" << options.norm_type() << ", "
405          << "kernel_size=" << options.kernel_size() << ", "
406          << "stride=" << options.stride() << ", "
407          << "ceil_mode=" << options.ceil_mode() << ")";
408 }
409 
forward(const Tensor & input)410 Tensor LPPool1dImpl::forward(const Tensor& input) {
411   return F::detail::lp_pool1d(
412       input,
413       options.norm_type(),
414       options.kernel_size(),
415       options.stride(),
416       options.ceil_mode());
417 }
418 
419 template class LPPoolImpl<1, LPPool1dImpl>;
420 
forward(const Tensor & input)421 Tensor LPPool2dImpl::forward(const Tensor& input) {
422   return F::detail::lp_pool2d(
423       input,
424       options.norm_type(),
425       options.kernel_size(),
426       options.stride(),
427       options.ceil_mode());
428 }
429 
430 template class LPPoolImpl<2, LPPool2dImpl>;
431 
forward(const Tensor & input)432 Tensor LPPool3dImpl::forward(const Tensor& input) {
433   return F::detail::lp_pool3d(
434       input,
435       options.norm_type(),
436       options.kernel_size(),
437       options.stride(),
438       options.ceil_mode());
439 }
440 
441 template class LPPoolImpl<3, LPPool3dImpl>;
442 
443 } // namespace nn
444 } // namespace torch
445