xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/pooling.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/expanding_array.h>
4 #include <torch/nn/cloneable.h>
5 #include <torch/nn/functional/pooling.h>
6 #include <torch/nn/modules/common.h>
7 #include <torch/nn/options/pooling.h>
8 
9 #include <torch/csrc/Export.h>
10 
11 namespace torch {
12 namespace nn {
13 
14 /// Base class for all (dimension-specialized) avgpool modules.
15 template <size_t D, typename Derived>
16 class TORCH_API AvgPoolImpl : public torch::nn::Cloneable<Derived> {
17  public:
AvgPoolImpl(ExpandingArray<D> kernel_size)18   AvgPoolImpl(ExpandingArray<D> kernel_size)
19       : AvgPoolImpl(AvgPoolOptions<D>(kernel_size)) {}
20   explicit AvgPoolImpl(const AvgPoolOptions<D>& options_);
21 
22   void reset() override;
23 
24   /// Pretty prints the `AvgPool{1,2,3}d` module into the given `stream`.
25   void pretty_print(std::ostream& stream) const override;
26 
27   /// The options with which this `Module` was constructed.
28   AvgPoolOptions<D> options;
29 };
30 
31 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
32 
33 /// Applies avgpool over a 1-D input.
34 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool1d to learn
35 /// about the exact behavior of this module.
36 ///
37 /// See the documentation for `torch::nn::AvgPool1dOptions` class to learn what
38 /// constructor arguments are supported for this module.
39 ///
40 /// Example:
41 /// ```
42 /// AvgPool1d model(AvgPool1dOptions(3).stride(2));
43 /// ```
44 class TORCH_API AvgPool1dImpl : public AvgPoolImpl<1, AvgPool1dImpl> {
45  public:
46   using AvgPoolImpl<1, AvgPool1dImpl>::AvgPoolImpl;
47   Tensor forward(const Tensor& input);
48 };
49 
50 /// A `ModuleHolder` subclass for `AvgPool1dImpl`.
51 /// See the documentation for `AvgPool1dImpl` class to learn what methods it
52 /// provides, and examples of how to use `AvgPool1d` with
53 /// `torch::nn::AvgPool1dOptions`. See the documentation for `ModuleHolder` to
54 /// learn about PyTorch's module storage semantics.
55 TORCH_MODULE(AvgPool1d);
56 
57 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
58 
59 /// Applies avgpool over a 2-D input.
60 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool2d to learn
61 /// about the exact behavior of this module.
62 ///
63 /// See the documentation for `torch::nn::AvgPool2dOptions` class to learn what
64 /// constructor arguments are supported for this module.
65 ///
66 /// Example:
67 /// ```
68 /// AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2}));
69 /// ```
70 class TORCH_API AvgPool2dImpl : public AvgPoolImpl<2, AvgPool2dImpl> {
71  public:
72   using AvgPoolImpl<2, AvgPool2dImpl>::AvgPoolImpl;
73   Tensor forward(const Tensor& input);
74 };
75 
76 /// A `ModuleHolder` subclass for `AvgPool2dImpl`.
77 /// See the documentation for `AvgPool2dImpl` class to learn what methods it
78 /// provides, and examples of how to use `AvgPool2d` with
79 /// `torch::nn::AvgPool2dOptions`. See the documentation for `ModuleHolder` to
80 /// learn about PyTorch's module storage semantics.
81 TORCH_MODULE(AvgPool2d);
82 
83 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
84 
85 /// Applies avgpool over a 3-D input.
86 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool3d to learn
87 /// about the exact behavior of this module.
88 ///
89 /// See the documentation for `torch::nn::AvgPool3dOptions` class to learn what
90 /// constructor arguments are supported for this module.
91 ///
92 /// Example:
93 /// ```
94 /// AvgPool3d model(AvgPool3dOptions(5).stride(2));
95 /// ```
96 class TORCH_API AvgPool3dImpl : public AvgPoolImpl<3, AvgPool3dImpl> {
97  public:
98   using AvgPoolImpl<3, AvgPool3dImpl>::AvgPoolImpl;
99   Tensor forward(const Tensor& input);
100 };
101 
102 /// A `ModuleHolder` subclass for `AvgPool3dImpl`.
103 /// See the documentation for `AvgPool3dImpl` class to learn what methods it
104 /// provides, and examples of how to use `AvgPool3d` with
105 /// `torch::nn::AvgPool3dOptions`. See the documentation for `ModuleHolder` to
106 /// learn about PyTorch's module storage semantics.
107 TORCH_MODULE(AvgPool3d);
108 
109 // ============================================================================
110 
111 /// Base class for all (dimension-specialized) maxpool modules.
112 template <size_t D, typename Derived>
113 class TORCH_API MaxPoolImpl : public torch::nn::Cloneable<Derived> {
114  public:
MaxPoolImpl(ExpandingArray<D> kernel_size)115   MaxPoolImpl(ExpandingArray<D> kernel_size)
116       : MaxPoolImpl(MaxPoolOptions<D>(kernel_size)) {}
117   explicit MaxPoolImpl(const MaxPoolOptions<D>& options_);
118 
119   void reset() override;
120 
121   /// Pretty prints the `MaxPool{1,2,3}d` module into the given `stream`.
122   void pretty_print(std::ostream& stream) const override;
123 
124   /// The options with which this `Module` was constructed.
125   MaxPoolOptions<D> options;
126 };
127 
128 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
129 
130 /// Applies maxpool over a 1-D input.
131 /// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool1d to learn
132 /// about the exact behavior of this module.
133 ///
134 /// See the documentation for `torch::nn::MaxPool1dOptions` class to learn what
135 /// constructor arguments are supported for this module.
136 ///
137 /// Example:
138 /// ```
139 /// MaxPool1d model(MaxPool1dOptions(3).stride(2));
140 /// ```
141 class TORCH_API MaxPool1dImpl : public MaxPoolImpl<1, MaxPool1dImpl> {
142  public:
143   using MaxPoolImpl<1, MaxPool1dImpl>::MaxPoolImpl;
144   Tensor forward(const Tensor& input);
145 
146   /// Returns the outputs and the indices of the max values.
147   /// Useful for `torch::nn::MaxUnpool1d` later.
148   std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
149 };
150 
151 /// A `ModuleHolder` subclass for `MaxPool1dImpl`.
152 /// See the documentation for `MaxPool1dImpl` class to learn what methods it
153 /// provides, and examples of how to use `MaxPool1d` with
154 /// `torch::nn::MaxPool1dOptions`. See the documentation for `ModuleHolder` to
155 /// learn about PyTorch's module storage semantics.
156 TORCH_MODULE(MaxPool1d);
157 
158 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
159 
160 /// Applies maxpool over a 2-D input.
161 /// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool2d to learn
162 /// about the exact behavior of this module.
163 ///
164 /// See the documentation for `torch::nn::MaxPool2dOptions` class to learn what
165 /// constructor arguments are supported for this module.
166 ///
167 /// Example:
168 /// ```
169 /// MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2}));
170 /// ```
171 class TORCH_API MaxPool2dImpl : public MaxPoolImpl<2, MaxPool2dImpl> {
172  public:
173   using MaxPoolImpl<2, MaxPool2dImpl>::MaxPoolImpl;
174   Tensor forward(const Tensor& input);
175 
176   /// Returns the outputs and the indices of the max values.
177   /// Useful for `torch::nn::MaxUnpool2d` later.
178   std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
179 };
180 
181 /// A `ModuleHolder` subclass for `MaxPool2dImpl`.
182 /// See the documentation for `MaxPool2dImpl` class to learn what methods it
183 /// provides, and examples of how to use `MaxPool2d` with
184 /// `torch::nn::MaxPool2dOptions`. See the documentation for `ModuleHolder` to
185 /// learn about PyTorch's module storage semantics.
186 TORCH_MODULE(MaxPool2d);
187 
188 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
189 
190 /// Applies maxpool over a 3-D input.
191 /// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool3d to learn
192 /// about the exact behavior of this module.
193 ///
194 /// See the documentation for `torch::nn::MaxPool3dOptions` class to learn what
195 /// constructor arguments are supported for this module.
196 ///
197 /// Example:
198 /// ```
199 /// MaxPool3d model(MaxPool3dOptions(3).stride(2));
200 /// ```
201 class TORCH_API MaxPool3dImpl : public MaxPoolImpl<3, MaxPool3dImpl> {
202  public:
203   using MaxPoolImpl<3, MaxPool3dImpl>::MaxPoolImpl;
204   Tensor forward(const Tensor& input);
205 
206   /// Returns the outputs and the indices of the max values.
207   /// Useful for `torch::nn::MaxUnpool3d` later.
208   std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
209 };
210 
211 /// A `ModuleHolder` subclass for `MaxPool3dImpl`.
212 /// See the documentation for `MaxPool3dImpl` class to learn what methods it
213 /// provides, and examples of how to use `MaxPool3d` with
214 /// `torch::nn::MaxPool3dOptions`. See the documentation for `ModuleHolder` to
215 /// learn about PyTorch's module storage semantics.
216 TORCH_MODULE(MaxPool3d);
217 
218 // ============================================================================
219 
220 /// Base class for all (dimension-specialized) adaptive maxpool modules.
221 template <size_t D, typename output_size_t, typename Derived>
222 class TORCH_API AdaptiveMaxPoolImpl : public torch::nn::Cloneable<Derived> {
223  public:
AdaptiveMaxPoolImpl(output_size_t output_size)224   AdaptiveMaxPoolImpl(output_size_t output_size)
225       : AdaptiveMaxPoolImpl(
226             AdaptiveMaxPoolOptions<output_size_t>(output_size)) {}
AdaptiveMaxPoolImpl(const AdaptiveMaxPoolOptions<output_size_t> & options_)227   explicit AdaptiveMaxPoolImpl(
228       const AdaptiveMaxPoolOptions<output_size_t>& options_)
229       : options(options_) {}
230 
reset()231   void reset() override{};
232 
233   /// Pretty prints the `AdaptiveMaxPool{1,2,3}d` module into the given
234   /// `stream`.
pretty_print(std::ostream & stream)235   void pretty_print(std::ostream& stream) const override {
236     stream << "torch::nn::AdaptiveMaxPool" << D << "d"
237            << "(output_size=" << options.output_size() << ")";
238   }
239 
240   /// The options with which this `Module` was constructed.
241   AdaptiveMaxPoolOptions<output_size_t> options;
242 };
243 
244 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
245 
246 /// Applies adaptive maxpool over a 1-D input.
247 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool1d to
248 /// learn about the exact behavior of this module.
249 ///
250 /// See the documentation for `torch::nn::AdaptiveMaxPool1dOptions` class to
251 /// learn what constructor arguments are supported for this module.
252 ///
253 /// Example:
254 /// ```
255 /// AdaptiveMaxPool1d model(AdaptiveMaxPool1dOptions(3));
256 /// ```
257 class TORCH_API AdaptiveMaxPool1dImpl
258     : public AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl> {
259  public:
260   using AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl>::
261       AdaptiveMaxPoolImpl;
262 
263   Tensor forward(const Tensor& input);
264 
265   /// Returns the indices along with the outputs.
266   /// Useful to pass to nn.MaxUnpool1d.
267   std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
268 };
269 
270 /// A `ModuleHolder` subclass for `AdaptiveMaxPool1dImpl`.
271 /// See the documentation for `AdaptiveMaxPool1dImpl` class to learn what
272 /// methods it provides, and examples of how to use `AdaptiveMaxPool1d` with
273 /// `torch::nn::AdaptiveMaxPool1dOptions`. See the documentation for
274 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
275 TORCH_MODULE(AdaptiveMaxPool1d);
276 
277 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
278 
279 /// Applies adaptive maxpool over a 2-D input.
280 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool2d to
281 /// learn about the exact behavior of this module.
282 ///
283 /// See the documentation for `torch::nn::AdaptiveMaxPool2dOptions` class to
284 /// learn what constructor arguments are supported for this module.
285 ///
286 /// Example:
287 /// ```
288 /// AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2}));
289 /// ```
290 class TORCH_API AdaptiveMaxPool2dImpl : public AdaptiveMaxPoolImpl<
291                                             2,
292                                             ExpandingArrayWithOptionalElem<2>,
293                                             AdaptiveMaxPool2dImpl> {
294  public:
295   using AdaptiveMaxPoolImpl<
296       2,
297       ExpandingArrayWithOptionalElem<2>,
298       AdaptiveMaxPool2dImpl>::AdaptiveMaxPoolImpl;
299 
300   Tensor forward(const Tensor& input);
301 
302   /// Returns the indices along with the outputs.
303   /// Useful to pass to nn.MaxUnpool2d.
304   std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
305 };
306 
307 /// A `ModuleHolder` subclass for `AdaptiveMaxPool2dImpl`.
308 /// See the documentation for `AdaptiveMaxPool2dImpl` class to learn what
309 /// methods it provides, and examples of how to use `AdaptiveMaxPool2d` with
310 /// `torch::nn::AdaptiveMaxPool2dOptions`. See the documentation for
311 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
312 TORCH_MODULE(AdaptiveMaxPool2d);
313 
314 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
315 
316 /// Applies adaptive maxpool over a 3-D input.
317 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool3d to
318 /// learn about the exact behavior of this module.
319 ///
320 /// See the documentation for `torch::nn::AdaptiveMaxPool3dOptions` class to
321 /// learn what constructor arguments are supported for this module.
322 ///
323 /// Example:
324 /// ```
325 /// AdaptiveMaxPool3d model(AdaptiveMaxPool3dOptions(3));
326 /// ```
327 class TORCH_API AdaptiveMaxPool3dImpl : public AdaptiveMaxPoolImpl<
328                                             3,
329                                             ExpandingArrayWithOptionalElem<3>,
330                                             AdaptiveMaxPool3dImpl> {
331  public:
332   using AdaptiveMaxPoolImpl<
333       3,
334       ExpandingArrayWithOptionalElem<3>,
335       AdaptiveMaxPool3dImpl>::AdaptiveMaxPoolImpl;
336 
337   Tensor forward(const Tensor& input);
338 
339   /// Returns the indices along with the outputs.
340   /// Useful to pass to nn.MaxUnpool3d.
341   std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
342 };
343 
344 /// A `ModuleHolder` subclass for `AdaptiveMaxPool3dImpl`.
345 /// See the documentation for `AdaptiveMaxPool3dImpl` class to learn what
346 /// methods it provides, and examples of how to use `AdaptiveMaxPool3d` with
347 /// `torch::nn::AdaptiveMaxPool3dOptions`. See the documentation for
348 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
349 TORCH_MODULE(AdaptiveMaxPool3d);
350 
351 // ============================================================================
352 
353 /// Base class for all (dimension-specialized) adaptive avgpool modules.
354 template <size_t D, typename output_size_t, typename Derived>
355 class TORCH_API AdaptiveAvgPoolImpl : public torch::nn::Cloneable<Derived> {
356  public:
AdaptiveAvgPoolImpl(output_size_t output_size)357   AdaptiveAvgPoolImpl(output_size_t output_size)
358       : AdaptiveAvgPoolImpl(
359             AdaptiveAvgPoolOptions<output_size_t>(output_size)) {}
AdaptiveAvgPoolImpl(const AdaptiveAvgPoolOptions<output_size_t> & options_)360   explicit AdaptiveAvgPoolImpl(
361       const AdaptiveAvgPoolOptions<output_size_t>& options_)
362       : options(options_) {}
363 
reset()364   void reset() override {}
365 
366   /// Pretty prints the `AdaptiveAvgPool{1,2,3}d` module into the given
367   /// `stream`.
pretty_print(std::ostream & stream)368   void pretty_print(std::ostream& stream) const override {
369     stream << "torch::nn::AdaptiveAvgPool" << D << "d"
370            << "(output_size=" << options.output_size() << ")";
371   }
372 
373   /// The options with which this `Module` was constructed.
374   AdaptiveAvgPoolOptions<output_size_t> options;
375 };
376 
377 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
378 
379 /// Applies adaptive avgpool over a 1-D input.
380 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool1d to
381 /// learn about the exact behavior of this module.
382 ///
383 /// See the documentation for `torch::nn::AdaptiveAvgPool1dOptions` class to
384 /// learn what constructor arguments are supported for this module.
385 ///
386 /// Example:
387 /// ```
388 /// AdaptiveAvgPool1d model(AdaptiveAvgPool1dOptions(5));
389 /// ```
390 class TORCH_API AdaptiveAvgPool1dImpl
391     : public AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl> {
392  public:
393   using AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl>::
394       AdaptiveAvgPoolImpl;
395 
396   Tensor forward(const Tensor& input);
397 };
398 
399 /// A `ModuleHolder` subclass for `AdaptiveAvgPool1dImpl`.
400 /// See the documentation for `AdaptiveAvgPool1dImpl` class to learn what
401 /// methods it provides, and examples of how to use `AdaptiveAvgPool1d` with
402 /// `torch::nn::AdaptiveAvgPool1dOptions`. See the documentation for
403 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
404 TORCH_MODULE(AdaptiveAvgPool1d);
405 
406 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407 
408 /// Applies adaptive avgpool over a 2-D input.
409 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool2d to
410 /// learn about the exact behavior of this module.
411 ///
412 /// See the documentation for `torch::nn::AdaptiveAvgPool2dOptions` class to
413 /// learn what constructor arguments are supported for this module.
414 ///
415 /// Example:
416 /// ```
417 /// AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2}));
418 /// ```
419 class TORCH_API AdaptiveAvgPool2dImpl : public AdaptiveAvgPoolImpl<
420                                             2,
421                                             ExpandingArrayWithOptionalElem<2>,
422                                             AdaptiveAvgPool2dImpl> {
423  public:
424   using AdaptiveAvgPoolImpl<
425       2,
426       ExpandingArrayWithOptionalElem<2>,
427       AdaptiveAvgPool2dImpl>::AdaptiveAvgPoolImpl;
428 
429   Tensor forward(const Tensor& input);
430 };
431 
432 /// A `ModuleHolder` subclass for `AdaptiveAvgPool2dImpl`.
433 /// See the documentation for `AdaptiveAvgPool2dImpl` class to learn what
434 /// methods it provides, and examples of how to use `AdaptiveAvgPool2d` with
435 /// `torch::nn::AdaptiveAvgPool2dOptions`. See the documentation for
436 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
437 TORCH_MODULE(AdaptiveAvgPool2d);
438 
439 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
440 
441 /// Applies adaptive avgpool over a 3-D input.
442 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool3d to
443 /// learn about the exact behavior of this module.
444 ///
445 /// See the documentation for `torch::nn::AdaptiveAvgPool3dOptions` class to
446 /// learn what constructor arguments are supported for this module.
447 ///
448 /// Example:
449 /// ```
450 /// AdaptiveAvgPool3d model(AdaptiveAvgPool3dOptions(3));
451 /// ```
452 class TORCH_API AdaptiveAvgPool3dImpl : public AdaptiveAvgPoolImpl<
453                                             3,
454                                             ExpandingArrayWithOptionalElem<3>,
455                                             AdaptiveAvgPool3dImpl> {
456  public:
457   using AdaptiveAvgPoolImpl<
458       3,
459       ExpandingArrayWithOptionalElem<3>,
460       AdaptiveAvgPool3dImpl>::AdaptiveAvgPoolImpl;
461 
462   Tensor forward(const Tensor& input);
463 };
464 
465 /// A `ModuleHolder` subclass for `AdaptiveAvgPool3dImpl`.
466 /// See the documentation for `AdaptiveAvgPool3dImpl` class to learn what
467 /// methods it provides, and examples of how to use `AdaptiveAvgPool3d` with
468 /// `torch::nn::AdaptiveAvgPool3dOptions`. See the documentation for
469 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
470 TORCH_MODULE(AdaptiveAvgPool3d);
471 
472 // ============================================================================
473 
474 /// Base class for all (dimension-specialized) maxunpool modules.
475 template <size_t D, typename Derived>
476 class TORCH_API MaxUnpoolImpl : public torch::nn::Cloneable<Derived> {
477  public:
MaxUnpoolImpl(ExpandingArray<D> kernel_size)478   MaxUnpoolImpl(ExpandingArray<D> kernel_size)
479       : MaxUnpoolImpl(MaxUnpoolOptions<D>(kernel_size)) {}
480   explicit MaxUnpoolImpl(const MaxUnpoolOptions<D>& options_);
481 
482   void reset() override;
483 
484   /// Pretty prints the `MaxUnpool{1,2,3}d` module into the given `stream`.
485   void pretty_print(std::ostream& stream) const override;
486 
487   /// The options with which this `Module` was constructed.
488   MaxUnpoolOptions<D> options;
489 };
490 
491 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
492 
493 /// Applies maxunpool over a 1-D input.
494 /// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool1d to learn
495 /// about the exact behavior of this module.
496 ///
497 /// See the documentation for `torch::nn::MaxUnpool1dOptions` class to learn
498 /// what constructor arguments are supported for this module.
499 ///
500 /// Example:
501 /// ```
502 /// MaxUnpool1d model(MaxUnpool1dOptions(3).stride(2).padding(1));
503 /// ```
504 class TORCH_API MaxUnpool1dImpl : public MaxUnpoolImpl<1, MaxUnpool1dImpl> {
505  public:
506   using MaxUnpoolImpl<1, MaxUnpool1dImpl>::MaxUnpoolImpl;
507   Tensor forward(
508       const Tensor& input,
509       const Tensor& indices,
510       const std::optional<std::vector<int64_t>>& output_size = std::nullopt);
511 
512  protected:
513   FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional<std::vector<int64_t>>())})
514 };
515 
516 /// A `ModuleHolder` subclass for `MaxUnpool1dImpl`.
517 /// See the documentation for `MaxUnpool1dImpl` class to learn what methods it
518 /// provides, and examples of how to use `MaxUnpool1d` with
519 /// `torch::nn::MaxUnpool1dOptions`. See the documentation for `ModuleHolder` to
520 /// learn about PyTorch's module storage semantics.
521 TORCH_MODULE(MaxUnpool1d);
522 
523 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
524 
525 /// Applies maxunpool over a 2-D input.
526 /// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool2d to learn
527 /// about the exact behavior of this module.
528 ///
529 /// See the documentation for `torch::nn::MaxUnpool2dOptions` class to learn
530 /// what constructor arguments are supported for this module.
531 ///
532 /// Example:
533 /// ```
534 /// MaxUnpool2d model(MaxUnpool2dOptions(3).stride(2).padding(1));
535 /// ```
536 class TORCH_API MaxUnpool2dImpl : public MaxUnpoolImpl<2, MaxUnpool2dImpl> {
537  public:
538   using MaxUnpoolImpl<2, MaxUnpool2dImpl>::MaxUnpoolImpl;
539   Tensor forward(
540       const Tensor& input,
541       const Tensor& indices,
542       const std::optional<std::vector<int64_t>>& output_size = std::nullopt);
543 
544  protected:
545   FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional<std::vector<int64_t>>())})
546 };
547 
548 /// A `ModuleHolder` subclass for `MaxUnpool2dImpl`.
549 /// See the documentation for `MaxUnpool2dImpl` class to learn what methods it
550 /// provides, and examples of how to use `MaxUnpool2d` with
551 /// `torch::nn::MaxUnpool2dOptions`. See the documentation for `ModuleHolder` to
552 /// learn about PyTorch's module storage semantics.
553 TORCH_MODULE(MaxUnpool2d);
554 
555 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
556 
557 /// Applies maxunpool over a 3-D input.
558 /// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool3d to learn
559 /// about the exact behavior of this module.
560 ///
561 /// See the documentation for `torch::nn::MaxUnpool3dOptions` class to learn
562 /// what constructor arguments are supported for this module.
563 ///
564 /// Example:
565 /// ```
566 /// MaxUnpool3d model(MaxUnpool3dOptions(3).stride(2).padding(1));
567 /// ```
568 class TORCH_API MaxUnpool3dImpl : public MaxUnpoolImpl<3, MaxUnpool3dImpl> {
569  public:
570   using MaxUnpoolImpl<3, MaxUnpool3dImpl>::MaxUnpoolImpl;
571   Tensor forward(
572       const Tensor& input,
573       const Tensor& indices,
574       const std::optional<std::vector<int64_t>>& output_size = std::nullopt);
575 
576  protected:
577   FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional<std::vector<int64_t>>())})
578 };
579 
580 /// A `ModuleHolder` subclass for `MaxUnpool3dImpl`.
581 /// See the documentation for `MaxUnpool3dImpl` class to learn what methods it
582 /// provides, and examples of how to use `MaxUnpool3d` with
583 /// `torch::nn::MaxUnpool3dOptions`. See the documentation for `ModuleHolder` to
584 /// learn about PyTorch's module storage semantics.
585 TORCH_MODULE(MaxUnpool3d);
586 
587 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FractionalMaxPool2d
588 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
589 
590 /// Applies fractional maxpool over a 2-D input.
591 /// See https://pytorch.org/docs/main/nn.html#torch.nn.FractionalMaxPool2d to
592 /// learn about the exact behavior of this module.
593 ///
594 /// See the documentation for `torch::nn::FractionalMaxPool2dOptions` class to
595 /// learn what constructor arguments are supported for this module.
596 ///
597 /// Example:
598 /// ```
599 /// FractionalMaxPool2d model(FractionalMaxPool2dOptions(5).output_size(1));
600 /// ```
601 class TORCH_API FractionalMaxPool2dImpl
602     : public torch::nn::Cloneable<FractionalMaxPool2dImpl> {
603  public:
FractionalMaxPool2dImpl(ExpandingArray<2> kernel_size)604   FractionalMaxPool2dImpl(ExpandingArray<2> kernel_size)
605       : FractionalMaxPool2dImpl(FractionalMaxPool2dOptions(kernel_size)) {}
606   explicit FractionalMaxPool2dImpl(FractionalMaxPool2dOptions options_);
607 
608   void reset() override;
609 
610   /// Pretty prints the `FractionalMaxPool2d` module into the given `stream`.
611   void pretty_print(std::ostream& stream) const override;
612 
613   Tensor forward(const Tensor& input);
614 
615   /// Returns the outputs and the indices of the max values.
616   /// Useful for `torch::nn::MaxUnpool2d` later.
617   std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
618 
619   /// The options with which this `Module` was constructed.
620   FractionalMaxPool2dOptions options;
621 
622   Tensor _random_samples;
623 };
624 
625 /// A `ModuleHolder` subclass for `FractionalMaxPool2dImpl`.
626 /// See the documentation for `FractionalMaxPool2dImpl` class to learn what
627 /// methods it provides, and examples of how to use `FractionalMaxPool2d` with
628 /// `torch::nn::FractionalMaxPool2dOptions`. See the documentation for
629 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
630 TORCH_MODULE(FractionalMaxPool2d);
631 
632 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FractionalMaxPool3d
633 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
634 
635 /// Applies fractional maxpool over a 3-D input.
636 /// See https://pytorch.org/docs/main/nn.html#torch.nn.FractionalMaxPool3d to
637 /// learn about the exact behavior of this module.
638 ///
639 /// See the documentation for `torch::nn::FractionalMaxPool3dOptions` class to
640 /// learn what constructor arguments are supported for this module.
641 ///
642 /// Example:
643 /// ```
644 /// FractionalMaxPool3d model(FractionalMaxPool3dOptions(5).output_size(1));
645 /// ```
646 class TORCH_API FractionalMaxPool3dImpl
647     : public torch::nn::Cloneable<FractionalMaxPool3dImpl> {
648  public:
FractionalMaxPool3dImpl(ExpandingArray<3> kernel_size)649   FractionalMaxPool3dImpl(ExpandingArray<3> kernel_size)
650       : FractionalMaxPool3dImpl(FractionalMaxPool3dOptions(kernel_size)) {}
651   explicit FractionalMaxPool3dImpl(FractionalMaxPool3dOptions options_);
652 
653   void reset() override;
654 
655   /// Pretty prints the `FractionalMaxPool3d` module into the given `stream`.
656   void pretty_print(std::ostream& stream) const override;
657 
658   Tensor forward(const Tensor& input);
659 
660   /// Returns the outputs and the indices of the max values.
661   /// Useful for `torch::nn::MaxUnpool3d` later.
662   std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
663 
664   /// The options with which this `Module` was constructed.
665   FractionalMaxPool3dOptions options;
666 
667   Tensor _random_samples;
668 };
669 
670 /// A `ModuleHolder` subclass for `FractionalMaxPool3dImpl`.
671 /// See the documentation for `FractionalMaxPool3dImpl` class to learn what
672 /// methods it provides, and examples of how to use `FractionalMaxPool3d` with
673 /// `torch::nn::FractionalMaxPool3dOptions`. See the documentation for
674 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
675 TORCH_MODULE(FractionalMaxPool3d);
676 
677 // ============================================================================
678 
679 /// Base class for all (dimension-specialized) lppool modules.
680 template <size_t D, typename Derived>
681 class TORCH_API LPPoolImpl : public torch::nn::Cloneable<Derived> {
682  public:
LPPoolImpl(double norm_type,ExpandingArray<D> kernel_size)683   LPPoolImpl(double norm_type, ExpandingArray<D> kernel_size)
684       : LPPoolImpl(LPPoolOptions<D>(norm_type, kernel_size)) {}
685   explicit LPPoolImpl(const LPPoolOptions<D>& options_);
686 
687   void reset() override;
688 
689   /// Pretty prints the `LPPool{1,2}d` module into the given `stream`.
690   void pretty_print(std::ostream& stream) const override;
691 
692   LPPoolOptions<D> options;
693 };
694 
695 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
696 
697 /// Applies the LPPool1d function element-wise.
698 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool1d to learn
699 /// about the exact behavior of this module.
700 ///
701 /// See the documentation for `torch::nn::LPPool1dOptions` class to learn what
702 /// constructor arguments are supported for this module.
703 ///
704 /// Example:
705 /// ```
706 /// LPPool1d model(LPPool1dOptions(1, 2).stride(5).ceil_mode(true));
707 /// ```
708 class TORCH_API LPPool1dImpl : public LPPoolImpl<1, LPPool1dImpl> {
709  public:
710   using LPPoolImpl<1, LPPool1dImpl>::LPPoolImpl;
711 
712   Tensor forward(const Tensor& input);
713 };
714 
715 /// A `ModuleHolder` subclass for `LPPool1dImpl`.
716 /// See the documentation for `LPPool1dImpl` class to learn what methods it
717 /// provides, and examples of how to use `LPPool1d` with
718 /// `torch::nn::LPPool1dOptions`. See the documentation for `ModuleHolder` to
719 /// learn about PyTorch's module storage semantics.
720 TORCH_MODULE(LPPool1d);
721 
722 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
723 
724 /// Applies the LPPool2d function element-wise.
725 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool2d to learn
726 /// about the exact behavior of this module.
727 ///
728 /// See the documentation for `torch::nn::LPPool2dOptions` class to learn what
729 /// constructor arguments are supported for this module.
730 ///
731 /// Example:
732 /// ```
733 /// LPPool2d model(LPPool2dOptions(1, std::vector<int64_t>({3, 4})).stride({5,
734 /// 6}).ceil_mode(true));
735 /// ```
736 class TORCH_API LPPool2dImpl : public LPPoolImpl<2, LPPool2dImpl> {
737  public:
738   using LPPoolImpl<2, LPPool2dImpl>::LPPoolImpl;
739 
740   Tensor forward(const Tensor& input);
741 };
742 
743 /// A `ModuleHolder` subclass for `LPPool2dImpl`.
744 /// See the documentation for `LPPool2dImpl` class to learn what methods it
745 /// provides, and examples of how to use `LPPool2d` with
746 /// `torch::nn::LPPool2dOptions`. See the documentation for `ModuleHolder` to
747 /// learn about PyTorch's module storage semantics.
748 TORCH_MODULE(LPPool2d);
749 
750 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
751 
752 /// Applies the LPPool3d function element-wise.
753 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool3d to learn
754 /// about the exact behavior of this module.
755 ///
756 /// See the documentation for `torch::nn::LPPool3dOptions` class to learn what
757 /// constructor arguments are supported for this module.
758 ///
759 /// Example:
760 /// ```
761 /// LPPool3d model(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5})).stride(
762 /// {5, 6, 7}).ceil_mode(true));
763 /// ```
764 class TORCH_API LPPool3dImpl : public LPPoolImpl<3, LPPool3dImpl> {
765  public:
766   using LPPoolImpl<3, LPPool3dImpl>::LPPoolImpl;
767 
768   Tensor forward(const Tensor& input);
769 };
770 
771 /// A `ModuleHolder` subclass for `LPPool3dImpl`.
772 /// See the documentation for `LPPool3dImpl` class to learn what methods it
773 /// provides, and examples of how to use `LPPool3d` with
774 /// `torch::nn::LPPool3dOptions`. See the documentation for `ModuleHolder` to
775 /// learn about PyTorch's module storage semantics.
776 TORCH_MODULE(LPPool3d);
777 
778 } // namespace nn
779 } // namespace torch
780