1 #include <torch/nn/functional/linear.h> 2 #include <torch/nn/init.h> 3 #include <torch/nn/modules/linear.h> 4 5 #include <torch/types.h> 6 #include <torch/utils.h> 7 8 #include <cmath> 9 #include <cstdint> 10 11 namespace F = torch::nn::functional; 12 13 namespace torch { 14 namespace nn { 15 reset()16 void IdentityImpl::reset() {} 17 pretty_print(std::ostream & stream) const18 void IdentityImpl::pretty_print(std::ostream& stream) const { 19 stream << "torch::nn::Identity()"; 20 } 21 forward(const Tensor & input)22 Tensor IdentityImpl::forward(const Tensor& input) { 23 return input; 24 } 25 26 // ============================================================================ 27 LinearImpl(const LinearOptions & options_)28 LinearImpl::LinearImpl(const LinearOptions& options_) : options(options_) { 29 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 30 reset(); 31 } 32 reset()33 void LinearImpl::reset() { 34 weight = register_parameter( 35 "weight", torch::empty({options.out_features(), options.in_features()})); 36 if (options.bias()) { 37 bias = register_parameter("bias", torch::empty(options.out_features())); 38 } else { 39 bias = register_parameter("bias", {}, /*requires_grad=*/false); 40 } 41 42 reset_parameters(); 43 } 44 reset_parameters()45 void LinearImpl::reset_parameters() { 46 torch::nn::init::kaiming_uniform_( 47 weight, std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers) 48 if (bias.defined()) { 49 auto [fan_in, fan_out] = 50 torch::nn::init::_calculate_fan_in_and_fan_out(weight); 51 const auto bound = 1 / std::sqrt(fan_in); 52 torch::nn::init::uniform_(bias, -bound, bound); 53 } 54 } 55 pretty_print(std::ostream & stream) const56 void LinearImpl::pretty_print(std::ostream& stream) const { 57 stream << std::boolalpha 58 << "torch::nn::Linear(in_features=" << options.in_features() 59 << ", out_features=" << options.out_features() 60 << ", bias=" << options.bias() << ")"; 61 } 62 forward(const Tensor & input)63 Tensor LinearImpl::forward(const Tensor& input) { 64 return F::linear(input, weight, bias); 65 } 66 67 // ============================================================================ 68 FlattenImpl(const FlattenOptions & options_)69 FlattenImpl::FlattenImpl(const FlattenOptions& options_) : options(options_) {} 70 reset()71 void FlattenImpl::reset() {} 72 pretty_print(std::ostream & stream) const73 void FlattenImpl::pretty_print(std::ostream& stream) const { 74 stream << "torch::nn::Flatten(start_dim=" << options.start_dim() 75 << ", end_dim=" << options.end_dim() << ")"; 76 } 77 forward(const Tensor & input)78 Tensor FlattenImpl::forward(const Tensor& input) { 79 return input.flatten(options.start_dim(), options.end_dim()); 80 } 81 82 // ============================================================================ 83 UnflattenImpl(UnflattenOptions options_)84 UnflattenImpl::UnflattenImpl(UnflattenOptions options_) 85 : options(std::move(options_)) {} 86 reset()87 void UnflattenImpl::reset() {} 88 pretty_print(std::ostream & stream) const89 void UnflattenImpl::pretty_print(std::ostream& stream) const { 90 auto namedshape = options.namedshape(); 91 if (!namedshape.empty()) { 92 stream << "torch::nn::Unflatten(dim=\"" << options.dimname() 93 << "\", unflattened_size={"; 94 size_t i = 0; 95 for (; i < namedshape.size() - 1; ++i) { 96 stream << "{\"" << std::get<0>(namedshape[i]) << "\", " 97 << std::get<1>(namedshape[i]) << "}, "; 98 } 99 stream << "{\"" << std::get<0>(namedshape[i]) << "\", " 100 << std::get<1>(namedshape[i]) << "}})"; 101 } else { 102 stream << "torch::nn::Unflatten(dim=" << options.dim() 103 << ", unflattened_size={"; 104 auto sizes = options.sizes(); 105 size_t i = 0; 106 for (; i < sizes.size() - 1; ++i) { 107 stream << sizes[i] << ", "; 108 } 109 stream << sizes[i] << "})"; 110 } 111 } 112 forward(const Tensor & input)113 Tensor UnflattenImpl::forward(const Tensor& input) { 114 auto namedshape = options.namedshape(); 115 if (!namedshape.empty()) { 116 auto dimname = 117 torch::Dimname::fromSymbol(torch::Symbol::dimname(options.dimname())); 118 std::vector<int64_t> sizes; 119 std::vector<torch::Dimname> names; 120 for (auto i : namedshape) { 121 names.push_back( 122 torch::Dimname::fromSymbol(torch::Symbol::dimname(std::get<0>(i)))); 123 sizes.push_back(std::get<1>(i)); 124 } 125 return input.unflatten(dimname, sizes, names); 126 } 127 return input.unflatten(options.dim(), options.sizes()); 128 } 129 130 // ============================================================================ 131 BilinearImpl(const BilinearOptions & options_)132 BilinearImpl::BilinearImpl(const BilinearOptions& options_) 133 : options(options_) { 134 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 135 reset(); 136 } 137 reset()138 void BilinearImpl::reset() { 139 weight = register_parameter( 140 "weight", 141 torch::empty( 142 {options.out_features(), 143 options.in1_features(), 144 options.in2_features()})); 145 if (options.bias()) { 146 bias = register_parameter("bias", torch::empty(options.out_features())); 147 } else { 148 bias = register_parameter("bias", torch::Tensor(), /*requires_grad=*/false); 149 } 150 151 reset_parameters(); 152 } 153 reset_parameters()154 void BilinearImpl::reset_parameters() { 155 const auto bound = 1.0 / std::sqrt(weight.size(1)); 156 init::uniform_(weight, -bound, bound); 157 if (bias.defined()) { 158 init::uniform_(bias, -bound, bound); 159 } 160 } 161 pretty_print(std::ostream & stream) const162 void BilinearImpl::pretty_print(std::ostream& stream) const { 163 stream << std::boolalpha 164 << "torch::nn::Bilinear(in1_features=" << options.in1_features() 165 << ", in2_features=" << options.in2_features() 166 << ", out_features=" << options.out_features() 167 << ", bias=" << options.bias() << ")"; 168 } 169 forward(const Tensor & input1,const Tensor & input2)170 Tensor BilinearImpl::forward(const Tensor& input1, const Tensor& input2) { 171 return F::bilinear(input1, input2, weight, bias); 172 } 173 174 } // namespace nn 175 } // namespace torch 176