1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/types.h> 6 7 namespace torch { 8 namespace nn { 9 10 /// Options for the `Linear` module. 11 /// 12 /// Example: 13 /// ``` 14 /// Linear model(LinearOptions(5, 2).bias(false)); 15 /// ``` 16 struct TORCH_API LinearOptions { 17 LinearOptions(int64_t in_features, int64_t out_features); 18 /// size of each input sample 19 TORCH_ARG(int64_t, in_features); 20 21 /// size of each output sample 22 TORCH_ARG(int64_t, out_features); 23 24 /// If set to false, the layer will not learn an additive bias. Default: true 25 TORCH_ARG(bool, bias) = true; 26 }; 27 28 // ============================================================================ 29 30 /// Options for the `Flatten` module. 31 /// 32 /// Example: 33 /// ``` 34 /// Flatten model(FlattenOptions().start_dim(2).end_dim(4)); 35 /// ``` 36 struct TORCH_API FlattenOptions { 37 /// first dim to flatten 38 TORCH_ARG(int64_t, start_dim) = 1; 39 /// last dim to flatten 40 TORCH_ARG(int64_t, end_dim) = -1; 41 }; 42 43 // ============================================================================ 44 45 /// Options for the `Unflatten` module. 46 /// 47 /// Note: If input tensor is named, use dimname and namedshape arguments. 48 /// 49 /// Example: 50 /// ``` 51 /// Unflatten unnamed_model(UnflattenOptions(0, {2, 2})); 52 /// Unflatten named_model(UnflattenOptions("B", {{"B1", 2}, {"B2", 2}})); 53 /// ``` 54 struct TORCH_API UnflattenOptions { 55 typedef std::vector<std::pair<std::string, int64_t>> namedshape_t; 56 57 UnflattenOptions(int64_t dim, std::vector<int64_t> sizes); 58 UnflattenOptions(const char* dimname, namedshape_t namedshape); 59 UnflattenOptions(std::string dimname, namedshape_t namedshape); 60 61 /// dim to unflatten 62 TORCH_ARG(int64_t, dim); 63 /// name of dim to unflatten, for use with named tensors 64 TORCH_ARG(std::string, dimname); 65 /// new shape of unflattened dim 66 TORCH_ARG(std::vector<int64_t>, sizes); 67 /// new shape of unflattened dim with names, for use with named tensors 68 TORCH_ARG(namedshape_t, namedshape); 69 }; 70 71 // ============================================================================ 72 73 /// Options for the `Bilinear` module. 74 /// 75 /// Example: 76 /// ``` 77 /// Bilinear model(BilinearOptions(3, 2, 4).bias(false)); 78 /// ``` 79 struct TORCH_API BilinearOptions { 80 BilinearOptions( 81 int64_t in1_features, 82 int64_t in2_features, 83 int64_t out_features); 84 /// The number of features in input 1 (columns of the input1 matrix). 85 TORCH_ARG(int64_t, in1_features); 86 /// The number of features in input 2 (columns of the input2 matrix). 87 TORCH_ARG(int64_t, in2_features); 88 /// The number of output features to produce (columns of the output matrix). 89 TORCH_ARG(int64_t, out_features); 90 /// Whether to learn and add a bias after the bilinear transformation. 91 TORCH_ARG(bool, bias) = true; 92 }; 93 94 } // namespace nn 95 } // namespace torch 96