1 #include <torch/nn/modules/pooling.h>
2
3 #include <torch/expanding_array.h>
4
5 namespace F = torch::nn::functional;
6
7 namespace torch {
8 namespace nn {
9
10 template <size_t D, typename Derived>
AvgPoolImpl(const AvgPoolOptions<D> & options_)11 AvgPoolImpl<D, Derived>::AvgPoolImpl(const AvgPoolOptions<D>& options_)
12 : options(options_) {}
13
14 template <size_t D, typename Derived>
reset()15 void AvgPoolImpl<D, Derived>::reset() {}
16
17 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const18 void AvgPoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
19 stream << "torch::nn::AvgPool" << D << "d"
20 << "(kernel_size=" << options.kernel_size()
21 << ", stride=" << options.stride() << ", padding=" << options.padding()
22 << ")";
23 }
24
// Delegates to F::detail::avg_pool1d, passing every hyper-parameter from
// `options` through unchanged.
Tensor AvgPool1dImpl::forward(const Tensor& input) {
  const auto& opts = options;
  return F::detail::avg_pool1d(
      input,
      opts.kernel_size(),
      opts.stride(),
      opts.padding(),
      opts.ceil_mode(),
      opts.count_include_pad());
}
34
// Delegates to F::detail::avg_pool2d. Every hyper-parameter comes from
// `options`, including the optional divisor override.
Tensor AvgPool2dImpl::forward(const Tensor& input) {
  const auto& opts = options;
  return F::detail::avg_pool2d(
      input,
      opts.kernel_size(),
      opts.stride(),
      opts.padding(),
      opts.ceil_mode(),
      opts.count_include_pad(),
      opts.divisor_override());
}
45
forward(const Tensor & input)46 Tensor AvgPool3dImpl::forward(const Tensor& input) {
47 return F::detail::avg_pool3d(
48 input,
49 options.kernel_size(),
50 options.stride(),
51 options.padding(),
52 options.ceil_mode(),
53 options.count_include_pad(),
54 options.divisor_override());
55 }
56
57 template class AvgPoolImpl<1, AvgPool1dImpl>;
58 template class AvgPoolImpl<2, AvgPool2dImpl>;
59 template class AvgPoolImpl<3, AvgPool3dImpl>;
60
61 // ============================================================================
62
63 template <size_t D, typename Derived>
MaxPoolImpl(const MaxPoolOptions<D> & options_)64 MaxPoolImpl<D, Derived>::MaxPoolImpl(const MaxPoolOptions<D>& options_)
65 : options(options_) {}
66
67 template <size_t D, typename Derived>
reset()68 void MaxPoolImpl<D, Derived>::reset() {}
69
70 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const71 void MaxPoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
72 stream << std::boolalpha << "torch::nn::MaxPool" << D << "d"
73 << "(kernel_size=" << options.kernel_size()
74 << ", stride=" << options.stride() << ", padding=" << options.padding()
75 << ", dilation=" << options.dilation()
76 << ", ceil_mode=" << options.ceil_mode() << ")";
77 }
78
forward(const Tensor & input)79 Tensor MaxPool1dImpl::forward(const Tensor& input) {
80 return F::detail::max_pool1d(
81 input,
82 options.kernel_size(),
83 options.stride(),
84 options.padding(),
85 options.dilation(),
86 options.ceil_mode());
87 }
88
forward_with_indices(const Tensor & input)89 std::tuple<Tensor, Tensor> MaxPool1dImpl::forward_with_indices(
90 const Tensor& input) {
91 return F::detail::max_pool1d_with_indices(
92 input,
93 options.kernel_size(),
94 options.stride(),
95 options.padding(),
96 options.dilation(),
97 options.ceil_mode());
98 }
99
forward(const Tensor & input)100 Tensor MaxPool2dImpl::forward(const Tensor& input) {
101 return F::detail::max_pool2d(
102 input,
103 options.kernel_size(),
104 options.stride(),
105 options.padding(),
106 options.dilation(),
107 options.ceil_mode());
108 }
109
forward_with_indices(const Tensor & input)110 std::tuple<Tensor, Tensor> MaxPool2dImpl::forward_with_indices(
111 const Tensor& input) {
112 return F::detail::max_pool2d_with_indices(
113 input,
114 options.kernel_size(),
115 options.stride(),
116 options.padding(),
117 options.dilation(),
118 options.ceil_mode());
119 }
120
forward(const Tensor & input)121 Tensor MaxPool3dImpl::forward(const Tensor& input) {
122 return F::detail::max_pool3d(
123 input,
124 options.kernel_size(),
125 options.stride(),
126 options.padding(),
127 options.dilation(),
128 options.ceil_mode());
129 }
130
forward_with_indices(const Tensor & input)131 std::tuple<Tensor, Tensor> MaxPool3dImpl::forward_with_indices(
132 const Tensor& input) {
133 return F::detail::max_pool3d_with_indices(
134 input,
135 options.kernel_size(),
136 options.stride(),
137 options.padding(),
138 options.dilation(),
139 options.ceil_mode());
140 }
141
142 template class MaxPoolImpl<1, MaxPool1dImpl>;
143 template class MaxPoolImpl<2, MaxPool2dImpl>;
144 template class MaxPoolImpl<3, MaxPool3dImpl>;
145
146 // ============================================================================
147
forward(const Tensor & input)148 Tensor AdaptiveMaxPool1dImpl::forward(const Tensor& input) {
149 return F::detail::adaptive_max_pool1d(input, options.output_size());
150 }
151
forward_with_indices(const Tensor & input)152 std::tuple<Tensor, Tensor> AdaptiveMaxPool1dImpl::forward_with_indices(
153 const Tensor& input) {
154 return F::detail::adaptive_max_pool1d_with_indices(
155 input, options.output_size());
156 }
157
forward(const Tensor & input)158 Tensor AdaptiveMaxPool2dImpl::forward(const Tensor& input) {
159 return F::detail::adaptive_max_pool2d(input, options.output_size());
160 }
161
forward_with_indices(const Tensor & input)162 std::tuple<Tensor, Tensor> AdaptiveMaxPool2dImpl::forward_with_indices(
163 const Tensor& input) {
164 return F::detail::adaptive_max_pool2d_with_indices(
165 input, options.output_size());
166 }
167
forward(const Tensor & input)168 Tensor AdaptiveMaxPool3dImpl::forward(const Tensor& input) {
169 return F::detail::adaptive_max_pool3d(input, options.output_size());
170 }
171
forward_with_indices(const Tensor & input)172 std::tuple<Tensor, Tensor> AdaptiveMaxPool3dImpl::forward_with_indices(
173 const Tensor& input) {
174 return F::detail::adaptive_max_pool3d_with_indices(
175 input, options.output_size());
176 }
177
178 template class AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl>;
179 template class AdaptiveMaxPoolImpl<
180 2,
181 ExpandingArrayWithOptionalElem<2>,
182 AdaptiveMaxPool2dImpl>;
183 template class AdaptiveMaxPoolImpl<
184 3,
185 ExpandingArrayWithOptionalElem<3>,
186 AdaptiveMaxPool3dImpl>;
187
188 // ============================================================================
189
// Delegates to F::detail::adaptive_avg_pool1d using the configured output size.
Tensor AdaptiveAvgPool1dImpl::forward(const Tensor& input) {
  const auto& target_size = options.output_size();
  return F::detail::adaptive_avg_pool1d(input, target_size);
}
193
// Delegates to F::detail::adaptive_avg_pool2d using the configured output size.
Tensor AdaptiveAvgPool2dImpl::forward(const Tensor& input) {
  const auto& target_size = options.output_size();
  return F::detail::adaptive_avg_pool2d(input, target_size);
}
197
forward(const Tensor & input)198 Tensor AdaptiveAvgPool3dImpl::forward(const Tensor& input) {
199 return F::detail::adaptive_avg_pool3d(input, options.output_size());
200 }
201
202 template class AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl>;
203 template class AdaptiveAvgPoolImpl<
204 2,
205 ExpandingArrayWithOptionalElem<2>,
206 AdaptiveAvgPool2dImpl>;
207 template class AdaptiveAvgPoolImpl<
208 3,
209 ExpandingArrayWithOptionalElem<3>,
210 AdaptiveAvgPool3dImpl>;
211
212 // ============================================================================
213
214 template <size_t D, typename Derived>
MaxUnpoolImpl(const MaxUnpoolOptions<D> & options_)215 MaxUnpoolImpl<D, Derived>::MaxUnpoolImpl(const MaxUnpoolOptions<D>& options_)
216 : options(options_) {}
217
218 template <size_t D, typename Derived>
reset()219 void MaxUnpoolImpl<D, Derived>::reset() {}
220
221 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const222 void MaxUnpoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
223 stream << std::boolalpha << "torch::nn::MaxUnpool" << D << "d"
224 << "(kernel_size=" << options.kernel_size()
225 << ", stride=" << options.stride() << ", padding=" << options.padding()
226 << ")";
227 }
228
// Delegates to F::detail::max_unpool1d. `indices` and the optional
// caller-requested `output_size` are forwarded as-is; kernel_size, stride and
// padding come from `options`.
Tensor MaxUnpool1dImpl::forward(
    const Tensor& input,
    const Tensor& indices,
    const std::optional<std::vector<int64_t>>& output_size) {
  const auto& opts = options;
  return F::detail::max_unpool1d(
      input,
      indices,
      opts.kernel_size(),
      opts.stride(),
      opts.padding(),
      output_size);
}
241
// Delegates to F::detail::max_unpool2d. `indices` and the optional
// caller-requested `output_size` are forwarded as-is; kernel_size, stride and
// padding come from `options`.
Tensor MaxUnpool2dImpl::forward(
    const Tensor& input,
    const Tensor& indices,
    const std::optional<std::vector<int64_t>>& output_size) {
  const auto& opts = options;
  return F::detail::max_unpool2d(
      input,
      indices,
      opts.kernel_size(),
      opts.stride(),
      opts.padding(),
      output_size);
}
254
forward(const Tensor & input,const Tensor & indices,const std::optional<std::vector<int64_t>> & output_size)255 Tensor MaxUnpool3dImpl::forward(
256 const Tensor& input,
257 const Tensor& indices,
258 const std::optional<std::vector<int64_t>>& output_size) {
259 return F::detail::max_unpool3d(
260 input,
261 indices,
262 options.kernel_size(),
263 options.stride(),
264 options.padding(),
265 output_size);
266 }
267
268 template class MaxUnpoolImpl<1, MaxUnpool1dImpl>;
269 template class MaxUnpoolImpl<2, MaxUnpool2dImpl>;
270 template class MaxUnpoolImpl<3, MaxUnpool3dImpl>;
271
272 // ============================================================================
273
FractionalMaxPool2dImpl(FractionalMaxPool2dOptions options_)274 FractionalMaxPool2dImpl::FractionalMaxPool2dImpl(
275 FractionalMaxPool2dOptions options_)
276 : options(std::move(options_)) {
277 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
278 reset();
279 }
280
reset()281 void FractionalMaxPool2dImpl::reset() {
282 _random_samples =
283 register_buffer("_random_samples", options._random_samples());
284 if (options.output_size() == std::nullopt &&
285 options.output_ratio() == std::nullopt) {
286 TORCH_CHECK(
287 false,
288 "FractionalMaxPool2d requires specifying either ",
289 "an output size, or a pooling ratio");
290 }
291 if (options.output_size() != std::nullopt &&
292 options.output_ratio() != std::nullopt) {
293 TORCH_CHECK(
294 false, "only one of output_size and output_ratio may be specified");
295 }
296 if (options.output_ratio() != std::nullopt) {
297 at::ArrayRef<double> output_ratio =
298 at::ArrayRef<double>(options.output_ratio().value());
299 if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] &&
300 output_ratio[1] < 1)) {
301 TORCH_CHECK(
302 false,
303 "output_ratio must be between 0 and 1 (got ",
304 output_ratio,
305 ")");
306 }
307 }
308 }
309
forward(const Tensor & input)310 Tensor FractionalMaxPool2dImpl::forward(const Tensor& input) {
311 return F::detail::fractional_max_pool2d(
312 input,
313 options.kernel_size(),
314 options.output_size(),
315 options.output_ratio(),
316 _random_samples);
317 }
318
forward_with_indices(const Tensor & input)319 std::tuple<Tensor, Tensor> FractionalMaxPool2dImpl::forward_with_indices(
320 const Tensor& input) {
321 return F::detail::fractional_max_pool2d_with_indices(
322 input,
323 options.kernel_size(),
324 options.output_size(),
325 options.output_ratio(),
326 _random_samples);
327 }
328
pretty_print(std::ostream & stream) const329 void FractionalMaxPool2dImpl::pretty_print(std::ostream& stream) const {
330 stream << "torch::nn::FractionalMaxPool2d()";
331 }
332
FractionalMaxPool3dImpl(FractionalMaxPool3dOptions options_)333 FractionalMaxPool3dImpl::FractionalMaxPool3dImpl(
334 FractionalMaxPool3dOptions options_)
335 : options(std::move(options_)) {
336 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
337 reset();
338 }
339
reset()340 void FractionalMaxPool3dImpl::reset() {
341 _random_samples =
342 register_buffer("_random_samples", options._random_samples());
343 if (options.output_size() == std::nullopt &&
344 options.output_ratio() == std::nullopt) {
345 TORCH_CHECK(
346 false,
347 "FractionalMaxPool3d requires specifying either ",
348 "an output size, or a pooling ratio");
349 }
350 if (options.output_size() != std::nullopt &&
351 options.output_ratio() != std::nullopt) {
352 TORCH_CHECK(
353 false, "only one of output_size and output_ratio may be specified");
354 }
355 if (options.output_ratio() != std::nullopt) {
356 at::ArrayRef<double> output_ratio =
357 at::ArrayRef<double>(options.output_ratio().value());
358 if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] &&
359 output_ratio[1] < 1 && 0 < output_ratio[2] && output_ratio[2] < 1)) {
360 TORCH_CHECK(
361 false,
362 "output_ratio must be between 0 and 1 (got ",
363 output_ratio,
364 ")");
365 }
366 }
367 }
368
forward(const Tensor & input)369 Tensor FractionalMaxPool3dImpl::forward(const Tensor& input) {
370 return F::detail::fractional_max_pool3d(
371 input,
372 options.kernel_size(),
373 options.output_size(),
374 options.output_ratio(),
375 _random_samples);
376 }
377
forward_with_indices(const Tensor & input)378 std::tuple<Tensor, Tensor> FractionalMaxPool3dImpl::forward_with_indices(
379 const Tensor& input) {
380 return F::detail::fractional_max_pool3d_with_indices(
381 input,
382 options.kernel_size(),
383 options.output_size(),
384 options.output_ratio(),
385 _random_samples);
386 }
387
pretty_print(std::ostream & stream) const388 void FractionalMaxPool3dImpl::pretty_print(std::ostream& stream) const {
389 stream << "torch::nn::FractionalMaxPool3d()";
390 }
391
392 // ============================================================================
393
394 template <size_t D, typename Derived>
LPPoolImpl(const LPPoolOptions<D> & options_)395 LPPoolImpl<D, Derived>::LPPoolImpl(const LPPoolOptions<D>& options_)
396 : options(options_) {}
397
398 template <size_t D, typename Derived>
reset()399 void LPPoolImpl<D, Derived>::reset() {}
400
401 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const402 void LPPoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
403 stream << std::boolalpha << "torch::nn::LPPool" << D << "d("
404 << "norm_type=" << options.norm_type() << ", "
405 << "kernel_size=" << options.kernel_size() << ", "
406 << "stride=" << options.stride() << ", "
407 << "ceil_mode=" << options.ceil_mode() << ")";
408 }
409
forward(const Tensor & input)410 Tensor LPPool1dImpl::forward(const Tensor& input) {
411 return F::detail::lp_pool1d(
412 input,
413 options.norm_type(),
414 options.kernel_size(),
415 options.stride(),
416 options.ceil_mode());
417 }
418
419 template class LPPoolImpl<1, LPPool1dImpl>;
420
forward(const Tensor & input)421 Tensor LPPool2dImpl::forward(const Tensor& input) {
422 return F::detail::lp_pool2d(
423 input,
424 options.norm_type(),
425 options.kernel_size(),
426 options.stride(),
427 options.ceil_mode());
428 }
429
430 template class LPPoolImpl<2, LPPool2dImpl>;
431
forward(const Tensor & input)432 Tensor LPPool3dImpl::forward(const Tensor& input) {
433 return F::detail::lp_pool3d(
434 input,
435 options.norm_type(),
436 options.kernel_size(),
437 options.stride(),
438 options.ceil_mode());
439 }
440
441 template class LPPoolImpl<3, LPPool3dImpl>;
442
443 } // namespace nn
444 } // namespace torch
445