xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/linear.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/types.h>
6 
7 namespace torch {
8 namespace nn {
9 
10 /// Options for the `Linear` module.
11 ///
12 /// Example:
13 /// ```
14 /// Linear model(LinearOptions(5, 2).bias(false));
15 /// ```
16 struct TORCH_API LinearOptions {
17   LinearOptions(int64_t in_features, int64_t out_features);
18   /// size of each input sample
19   TORCH_ARG(int64_t, in_features);
20 
21   /// size of each output sample
22   TORCH_ARG(int64_t, out_features);
23 
24   /// If set to false, the layer will not learn an additive bias. Default: true
25   TORCH_ARG(bool, bias) = true;
26 };
27 
28 // ============================================================================
29 
30 /// Options for the `Flatten` module.
31 ///
32 /// Example:
33 /// ```
34 /// Flatten model(FlattenOptions().start_dim(2).end_dim(4));
35 /// ```
36 struct TORCH_API FlattenOptions {
37   /// first dim to flatten
38   TORCH_ARG(int64_t, start_dim) = 1;
39   /// last dim to flatten
40   TORCH_ARG(int64_t, end_dim) = -1;
41 };
42 
43 // ============================================================================
44 
45 /// Options for the `Unflatten` module.
46 ///
47 /// Note: If input tensor is named, use dimname and namedshape arguments.
48 ///
49 /// Example:
50 /// ```
51 /// Unflatten unnamed_model(UnflattenOptions(0, {2, 2}));
52 /// Unflatten named_model(UnflattenOptions("B", {{"B1", 2}, {"B2", 2}}));
53 /// ```
54 struct TORCH_API UnflattenOptions {
55   typedef std::vector<std::pair<std::string, int64_t>> namedshape_t;
56 
57   UnflattenOptions(int64_t dim, std::vector<int64_t> sizes);
58   UnflattenOptions(const char* dimname, namedshape_t namedshape);
59   UnflattenOptions(std::string dimname, namedshape_t namedshape);
60 
61   /// dim to unflatten
62   TORCH_ARG(int64_t, dim);
63   /// name of dim to unflatten, for use with named tensors
64   TORCH_ARG(std::string, dimname);
65   /// new shape of unflattened dim
66   TORCH_ARG(std::vector<int64_t>, sizes);
67   /// new shape of unflattened dim with names, for use with named tensors
68   TORCH_ARG(namedshape_t, namedshape);
69 };
70 
71 // ============================================================================
72 
73 /// Options for the `Bilinear` module.
74 ///
75 /// Example:
76 /// ```
77 /// Bilinear model(BilinearOptions(3, 2, 4).bias(false));
78 /// ```
79 struct TORCH_API BilinearOptions {
80   BilinearOptions(
81       int64_t in1_features,
82       int64_t in2_features,
83       int64_t out_features);
84   /// The number of features in input 1 (columns of the input1 matrix).
85   TORCH_ARG(int64_t, in1_features);
86   /// The number of features in input 2 (columns of the input2 matrix).
87   TORCH_ARG(int64_t, in2_features);
88   /// The number of output features to produce (columns of the output matrix).
89   TORCH_ARG(int64_t, out_features);
90   /// Whether to learn and add a bias after the bilinear transformation.
91   TORCH_ARG(bool, bias) = true;
92 };
93 
94 } // namespace nn
95 } // namespace torch
96