// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/linear.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1  #include <torch/nn/functional/linear.h>
2  #include <torch/nn/init.h>
3  #include <torch/nn/modules/linear.h>
4  
5  #include <torch/types.h>
6  #include <torch/utils.h>
7  
8  #include <cmath>
9  #include <cstdint>
10  
11  namespace F = torch::nn::functional;
12  
13  namespace torch {
14  namespace nn {
15  
reset()16  void IdentityImpl::reset() {}
17  
pretty_print(std::ostream & stream) const18  void IdentityImpl::pretty_print(std::ostream& stream) const {
19    stream << "torch::nn::Identity()";
20  }
21  
Tensor IdentityImpl::forward(const Tensor& input) {
  // Pass the input through unchanged.
  return input;
}
25  
26  // ============================================================================
27  
LinearImpl(const LinearOptions & options_)28  LinearImpl::LinearImpl(const LinearOptions& options_) : options(options_) {
29    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
30    reset();
31  }
32  
reset()33  void LinearImpl::reset() {
34    weight = register_parameter(
35        "weight", torch::empty({options.out_features(), options.in_features()}));
36    if (options.bias()) {
37      bias = register_parameter("bias", torch::empty(options.out_features()));
38    } else {
39      bias = register_parameter("bias", {}, /*requires_grad=*/false);
40    }
41  
42    reset_parameters();
43  }
44  
reset_parameters()45  void LinearImpl::reset_parameters() {
46    torch::nn::init::kaiming_uniform_(
47        weight, std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
48    if (bias.defined()) {
49      auto [fan_in, fan_out] =
50          torch::nn::init::_calculate_fan_in_and_fan_out(weight);
51      const auto bound = 1 / std::sqrt(fan_in);
52      torch::nn::init::uniform_(bias, -bound, bound);
53    }
54  }
55  
pretty_print(std::ostream & stream) const56  void LinearImpl::pretty_print(std::ostream& stream) const {
57    stream << std::boolalpha
58           << "torch::nn::Linear(in_features=" << options.in_features()
59           << ", out_features=" << options.out_features()
60           << ", bias=" << options.bias() << ")";
61  }
62  
Tensor LinearImpl::forward(const Tensor& input) {
  // Delegates to the functional form; bias may be undefined (no-bias mode).
  return F::linear(input, weight, bias);
}
66  
67  // ============================================================================
68  
FlattenImpl(const FlattenOptions & options_)69  FlattenImpl::FlattenImpl(const FlattenOptions& options_) : options(options_) {}
70  
reset()71  void FlattenImpl::reset() {}
72  
pretty_print(std::ostream & stream) const73  void FlattenImpl::pretty_print(std::ostream& stream) const {
74    stream << "torch::nn::Flatten(start_dim=" << options.start_dim()
75           << ", end_dim=" << options.end_dim() << ")";
76  }
77  
Tensor FlattenImpl::forward(const Tensor& input) {
  // Collapse dimensions [start_dim, end_dim] into one.
  const auto start = options.start_dim();
  const auto end = options.end_dim();
  return input.flatten(start, end);
}
81  
82  // ============================================================================
83  
UnflattenImpl(UnflattenOptions options_)84  UnflattenImpl::UnflattenImpl(UnflattenOptions options_)
85      : options(std::move(options_)) {}
86  
reset()87  void UnflattenImpl::reset() {}
88  
pretty_print(std::ostream & stream) const89  void UnflattenImpl::pretty_print(std::ostream& stream) const {
90    auto namedshape = options.namedshape();
91    if (!namedshape.empty()) {
92      stream << "torch::nn::Unflatten(dim=\"" << options.dimname()
93             << "\", unflattened_size={";
94      size_t i = 0;
95      for (; i < namedshape.size() - 1; ++i) {
96        stream << "{\"" << std::get<0>(namedshape[i]) << "\", "
97               << std::get<1>(namedshape[i]) << "}, ";
98      }
99      stream << "{\"" << std::get<0>(namedshape[i]) << "\", "
100             << std::get<1>(namedshape[i]) << "}})";
101    } else {
102      stream << "torch::nn::Unflatten(dim=" << options.dim()
103             << ", unflattened_size={";
104      auto sizes = options.sizes();
105      size_t i = 0;
106      for (; i < sizes.size() - 1; ++i) {
107        stream << sizes[i] << ", ";
108      }
109      stream << sizes[i] << "})";
110    }
111  }
112  
Tensor UnflattenImpl::forward(const Tensor& input) {
  const auto namedshape = options.namedshape();
  if (namedshape.empty()) {
    // Plain integer-dim unflatten.
    return input.unflatten(options.dim(), options.sizes());
  }
  // Named-dimension unflatten: translate the (name, size) pairs into
  // parallel Dimname / size vectors expected by Tensor::unflatten.
  const auto dim =
      torch::Dimname::fromSymbol(torch::Symbol::dimname(options.dimname()));
  std::vector<int64_t> sizes;
  std::vector<torch::Dimname> names;
  sizes.reserve(namedshape.size());
  names.reserve(namedshape.size());
  for (const auto& entry : namedshape) {
    names.push_back(
        torch::Dimname::fromSymbol(torch::Symbol::dimname(std::get<0>(entry))));
    sizes.push_back(std::get<1>(entry));
  }
  return input.unflatten(dim, sizes, names);
}
129  
130  // ============================================================================
131  
BilinearImpl(const BilinearOptions & options_)132  BilinearImpl::BilinearImpl(const BilinearOptions& options_)
133      : options(options_) {
134    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
135    reset();
136  }
137  
reset()138  void BilinearImpl::reset() {
139    weight = register_parameter(
140        "weight",
141        torch::empty(
142            {options.out_features(),
143             options.in1_features(),
144             options.in2_features()}));
145    if (options.bias()) {
146      bias = register_parameter("bias", torch::empty(options.out_features()));
147    } else {
148      bias = register_parameter("bias", torch::Tensor(), /*requires_grad=*/false);
149    }
150  
151    reset_parameters();
152  }
153  
reset_parameters()154  void BilinearImpl::reset_parameters() {
155    const auto bound = 1.0 / std::sqrt(weight.size(1));
156    init::uniform_(weight, -bound, bound);
157    if (bias.defined()) {
158      init::uniform_(bias, -bound, bound);
159    }
160  }
161  
pretty_print(std::ostream & stream) const162  void BilinearImpl::pretty_print(std::ostream& stream) const {
163    stream << std::boolalpha
164           << "torch::nn::Bilinear(in1_features=" << options.in1_features()
165           << ", in2_features=" << options.in2_features()
166           << ", out_features=" << options.out_features()
167           << ", bias=" << options.bias() << ")";
168  }
169  
Tensor BilinearImpl::forward(const Tensor& input1, const Tensor& input2) {
  // Delegates to the functional form; bias may be undefined (no-bias mode).
  return F::bilinear(input1, input2, weight, bias);
}
173  
174  } // namespace nn
175  } // namespace torch
176