xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/loss.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1  #include <torch/nn/modules/loss.h>
2  
3  namespace F = torch::nn::functional;
4  
5  namespace torch {
6  namespace nn {
7  
L1LossImpl(L1LossOptions options_)8  L1LossImpl::L1LossImpl(L1LossOptions options_) : options(std::move(options_)) {}
9  
reset()10  void L1LossImpl::reset() {}
11  
pretty_print(std::ostream & stream) const12  void L1LossImpl::pretty_print(std::ostream& stream) const {
13    stream << "torch::nn::L1Loss()";
14  }
15  
forward(const Tensor & input,const Tensor & target)16  Tensor L1LossImpl::forward(const Tensor& input, const Tensor& target) {
17    return F::detail::l1_loss(input, target, options.reduction());
18  }
19  
20  // ============================================================================
21  
KLDivLossImpl(KLDivLossOptions options_)22  KLDivLossImpl::KLDivLossImpl(KLDivLossOptions options_)
23      : options(std::move(options_)) {}
24  
reset()25  void KLDivLossImpl::reset() {}
26  
pretty_print(std::ostream & stream) const27  void KLDivLossImpl::pretty_print(std::ostream& stream) const {
28    stream << "torch::nn::KLDivLoss()";
29  }
30  
forward(const Tensor & input,const Tensor & target)31  Tensor KLDivLossImpl::forward(const Tensor& input, const Tensor& target) {
32    return F::detail::kl_div(
33        input, target, options.reduction(), options.log_target());
34  }
35  
36  // ============================================================================
37  
MSELossImpl(MSELossOptions options_)38  MSELossImpl::MSELossImpl(MSELossOptions options_)
39      : options(std::move(options_)) {}
40  
reset()41  void MSELossImpl::reset() {}
42  
pretty_print(std::ostream & stream) const43  void MSELossImpl::pretty_print(std::ostream& stream) const {
44    stream << "torch::nn::MSELoss()";
45  }
46  
forward(const Tensor & input,const Tensor & target)47  Tensor MSELossImpl::forward(const Tensor& input, const Tensor& target) {
48    return F::detail::mse_loss(input, target, options.reduction());
49  }
50  
51  // ============================================================================
52  
BCELossImpl(BCELossOptions options_)53  BCELossImpl::BCELossImpl(BCELossOptions options_)
54      : options(std::move(options_)) {
55    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
56    reset();
57  }
58  
reset()59  void BCELossImpl::reset() {
60    register_buffer("weight", options.weight());
61  }
62  
pretty_print(std::ostream & stream) const63  void BCELossImpl::pretty_print(std::ostream& stream) const {
64    stream << "torch::nn::BCELoss()";
65  }
66  
forward(const Tensor & input,const Tensor & target)67  Tensor BCELossImpl::forward(const Tensor& input, const Tensor& target) {
68    return F::detail::binary_cross_entropy(
69        input, target, options.weight(), options.reduction());
70  }
71  
72  // ============================================================================
73  
HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_)74  HingeEmbeddingLossImpl::HingeEmbeddingLossImpl(
75      HingeEmbeddingLossOptions options_)
76      : options(std::move(options_)) {}
77  
reset()78  void HingeEmbeddingLossImpl::reset() {}
79  
pretty_print(std::ostream & stream) const80  void HingeEmbeddingLossImpl::pretty_print(std::ostream& stream) const {
81    stream << "torch::nn::HingeEmbeddingLoss(margin=" << options.margin() << ")";
82  }
83  
forward(const Tensor & input,const Tensor & target)84  Tensor HingeEmbeddingLossImpl::forward(
85      const Tensor& input,
86      const Tensor& target) {
87    return F::detail::hinge_embedding_loss(
88        input, target, options.margin(), options.reduction());
89  }
90  
91  // ============================================================================
92  
MultiMarginLossImpl(MultiMarginLossOptions options_)93  MultiMarginLossImpl::MultiMarginLossImpl(MultiMarginLossOptions options_)
94      : options(std::move(options_)) {
95    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
96    reset();
97  }
98  
reset()99  void MultiMarginLossImpl::reset() {
100    TORCH_CHECK(
101        (options.p() == 1) || (options.p() == 2),
102        "only p == 1 and p == 2 supported");
103    TORCH_CHECK(!options.weight().defined() || options.weight().dim() == 1);
104  
105    register_buffer("weight", options.weight());
106  }
107  
pretty_print(std::ostream & stream) const108  void MultiMarginLossImpl::pretty_print(std::ostream& stream) const {
109    stream << "torch::nn::MultiMarginLoss(p=" << options.p()
110           << ", margin=" << options.margin() << ", weight=" << options.weight()
111           << ", reduction=" << enumtype::get_enum_name(options.reduction())
112           << ")";
113  }
114  
forward(const Tensor & input,const Tensor & target)115  Tensor MultiMarginLossImpl::forward(const Tensor& input, const Tensor& target) {
116    return F::detail::multi_margin_loss(
117        input,
118        target,
119        options.p(),
120        options.margin(),
121        options.weight(),
122        options.reduction());
123  }
124  
125  // ============================================================================
126  
CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_)127  CosineEmbeddingLossImpl::CosineEmbeddingLossImpl(
128      CosineEmbeddingLossOptions options_)
129      : options(std::move(options_)) {}
130  
reset()131  void CosineEmbeddingLossImpl::reset() {}
132  
pretty_print(std::ostream & stream) const133  void CosineEmbeddingLossImpl::pretty_print(std::ostream& stream) const {
134    stream << "torch::nn::CosineEmbeddingLoss(margin=" << options.margin() << ")";
135  }
136  
forward(const Tensor & input1,const Tensor & input2,const Tensor & target)137  Tensor CosineEmbeddingLossImpl::forward(
138      const Tensor& input1,
139      const Tensor& input2,
140      const Tensor& target) {
141    return F::detail::cosine_embedding_loss(
142        input1, input2, target, options.margin(), options.reduction());
143  }
144  // ============================================================================
145  
MultiLabelSoftMarginLossImpl(torch::nn::MultiLabelSoftMarginLossOptions options_)146  MultiLabelSoftMarginLossImpl::MultiLabelSoftMarginLossImpl(
147      torch::nn::MultiLabelSoftMarginLossOptions options_)
148      : options(std::move(options_)) {
149    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
150    reset();
151  }
152  
pretty_print(std::ostream & stream) const153  void MultiLabelSoftMarginLossImpl::pretty_print(std::ostream& stream) const {
154    stream << "torch::nn::MultiLabelSoftMarginLoss()";
155  }
156  
reset()157  void MultiLabelSoftMarginLossImpl::reset() {
158    register_buffer("weight", options.weight());
159  }
160  
forward(const Tensor & input,const Tensor & target)161  Tensor MultiLabelSoftMarginLossImpl::forward(
162      const Tensor& input,
163      const Tensor& target) {
164    return F::detail::multilabel_soft_margin_loss(
165        input, target, options.weight(), options.reduction());
166  }
167  
168  // ============================================================================
169  
TripletMarginLossImpl(TripletMarginLossOptions options_)170  TripletMarginLossImpl::TripletMarginLossImpl(TripletMarginLossOptions options_)
171      : options(std::move(options_)) {}
172  
reset()173  void TripletMarginLossImpl::reset() {}
174  
pretty_print(std::ostream & stream) const175  void TripletMarginLossImpl::pretty_print(std::ostream& stream) const {
176    stream << "torch::nn::TripletMarginLoss(margin=" << options.margin()
177           << ", p=" << options.p() << ", eps=" << options.eps() << std::boolalpha
178           << ", swap=" << options.swap() << ")";
179  }
180  
forward(const Tensor & anchor,const Tensor & positive,const Tensor & negative)181  Tensor TripletMarginLossImpl::forward(
182      const Tensor& anchor,
183      const Tensor& positive,
184      const Tensor& negative) {
185    return F::detail::triplet_margin_loss(
186        anchor,
187        positive,
188        negative,
189        options.margin(),
190        options.p(),
191        options.eps(),
192        options.swap(),
193        options.reduction());
194  }
195  
196  // ============================================================================
197  
TripletMarginWithDistanceLossImpl(TripletMarginWithDistanceLossOptions options_)198  TripletMarginWithDistanceLossImpl::TripletMarginWithDistanceLossImpl(
199      TripletMarginWithDistanceLossOptions options_)
200      : options(std::move(options_)) {}
201  
reset()202  void TripletMarginWithDistanceLossImpl::reset() {}
203  
pretty_print(std::ostream & stream) const204  void TripletMarginWithDistanceLossImpl::pretty_print(
205      std::ostream& stream) const {
206    stream << "torch::nn::TripletMarginWithDistanceLoss(margin="
207           << options.margin() << std::boolalpha << ", swap=" << options.swap()
208           << ")";
209  }
210  
forward(const Tensor & anchor,const Tensor & positive,const Tensor & negative)211  Tensor TripletMarginWithDistanceLossImpl::forward(
212      const Tensor& anchor,
213      const Tensor& positive,
214      const Tensor& negative) {
215    return F::detail::triplet_margin_with_distance_loss(
216        anchor,
217        positive,
218        negative,
219        options.distance_function(),
220        options.margin(),
221        options.swap(),
222        options.reduction());
223  }
224  
225  // ============================================================================
226  
MultiLabelMarginLossImpl(torch::nn::MultiLabelMarginLossOptions options_)227  MultiLabelMarginLossImpl::MultiLabelMarginLossImpl(
228      torch::nn::MultiLabelMarginLossOptions options_)
229      : options(std::move(options_)) {}
230  
reset()231  void MultiLabelMarginLossImpl::reset() {}
232  
pretty_print(std::ostream & stream) const233  void MultiLabelMarginLossImpl::pretty_print(std::ostream& stream) const {
234    stream << "torch::nn::MultiLabelMarginLoss()";
235  }
236  
forward(const Tensor & input,const Tensor & target)237  Tensor MultiLabelMarginLossImpl::forward(
238      const Tensor& input,
239      const Tensor& target) {
240    return F::detail::multilabel_margin_loss(input, target, options.reduction());
241  }
242  
243  // ============================================================================
244  
SoftMarginLossImpl(torch::nn::SoftMarginLossOptions options_)245  SoftMarginLossImpl::SoftMarginLossImpl(
246      torch::nn::SoftMarginLossOptions options_)
247      : options(std::move(options_)) {}
248  
reset()249  void SoftMarginLossImpl::reset() {}
250  
pretty_print(std::ostream & stream) const251  void SoftMarginLossImpl::pretty_print(std::ostream& stream) const {
252    stream << "torch::nn::SoftMarginLoss()";
253  }
254  
forward(const Tensor & input,const Tensor & target)255  Tensor SoftMarginLossImpl::forward(const Tensor& input, const Tensor& target) {
256    return F::detail::soft_margin_loss(input, target, options.reduction());
257  }
258  
259  // ============================================================================
260  
SmoothL1LossImpl(torch::nn::SmoothL1LossOptions options_)261  SmoothL1LossImpl::SmoothL1LossImpl(torch::nn::SmoothL1LossOptions options_)
262      : options(std::move(options_)) {}
263  
reset()264  void SmoothL1LossImpl::reset() {}
265  
pretty_print(std::ostream & stream) const266  void SmoothL1LossImpl::pretty_print(std::ostream& stream) const {
267    stream << "torch::nn::SmoothL1Loss";
268  }
269  
forward(const Tensor & input,const Tensor & target)270  Tensor SmoothL1LossImpl::forward(const Tensor& input, const Tensor& target) {
271    return F::detail::smooth_l1_loss(
272        input, target, options.reduction(), options.beta());
273  }
274  
275  // ============================================================================
276  
HuberLossImpl(torch::nn::HuberLossOptions options_)277  HuberLossImpl::HuberLossImpl(torch::nn::HuberLossOptions options_)
278      : options(std::move(options_)) {}
279  
reset()280  void HuberLossImpl::reset() {}
281  
pretty_print(std::ostream & stream) const282  void HuberLossImpl::pretty_print(std::ostream& stream) const {
283    stream << "torch::nn::HuberLoss";
284  }
285  
forward(const Tensor & input,const Tensor & target)286  Tensor HuberLossImpl::forward(const Tensor& input, const Tensor& target) {
287    return F::detail::huber_loss(
288        input, target, options.reduction(), options.delta());
289  }
290  
291  // ============================================================================
292  
CTCLossImpl(CTCLossOptions options_)293  CTCLossImpl::CTCLossImpl(CTCLossOptions options_)
294      : options(std::move(options_)) {}
295  
reset()296  void CTCLossImpl::reset() {}
297  
pretty_print(std::ostream & stream) const298  void CTCLossImpl::pretty_print(std::ostream& stream) const {
299    stream << "torch::nn::CTCLoss()";
300  }
301  
forward(const Tensor & log_probs,const Tensor & targets,const Tensor & input_lengths,const Tensor & target_lengths)302  Tensor CTCLossImpl::forward(
303      const Tensor& log_probs,
304      const Tensor& targets,
305      const Tensor& input_lengths,
306      const Tensor& target_lengths) {
307    return F::detail::ctc_loss(
308        log_probs,
309        targets,
310        input_lengths,
311        target_lengths,
312        options.blank(),
313        options.reduction(),
314        options.zero_infinity());
315  }
316  
317  // ============================================================================
318  
PoissonNLLLossImpl(PoissonNLLLossOptions options_)319  PoissonNLLLossImpl::PoissonNLLLossImpl(PoissonNLLLossOptions options_)
320      : options(std::move(options_)) {}
321  
reset()322  void PoissonNLLLossImpl::reset() {}
323  
pretty_print(std::ostream & stream) const324  void PoissonNLLLossImpl::pretty_print(std::ostream& stream) const {
325    stream << "torch::nn::PoissonNLLLoss()";
326  }
327  
forward(const Tensor & log_input,const Tensor & target)328  Tensor PoissonNLLLossImpl::forward(
329      const Tensor& log_input,
330      const Tensor& target) {
331    return F::detail::poisson_nll_loss(
332        log_input,
333        target,
334        options.log_input(),
335        options.full(),
336        options.eps(),
337        options.reduction());
338  }
339  
340  // ============================================================================
341  
MarginRankingLossImpl(MarginRankingLossOptions options_)342  MarginRankingLossImpl::MarginRankingLossImpl(MarginRankingLossOptions options_)
343      : options(std::move(options_)) {}
344  
reset()345  void MarginRankingLossImpl::reset() {}
346  
pretty_print(std::ostream & stream) const347  void MarginRankingLossImpl::pretty_print(std::ostream& stream) const {
348    stream << "torch::nn::MarginRankingLoss()";
349  }
350  
forward(const Tensor & input1,const Tensor & input2,const Tensor & target)351  Tensor MarginRankingLossImpl::forward(
352      const Tensor& input1,
353      const Tensor& input2,
354      const Tensor& target) {
355    return F::detail::margin_ranking_loss(
356        input1, input2, target, options.margin(), options.reduction());
357  }
358  
359  // ============================================================================
360  
NLLLossImpl(NLLLossOptions options_)361  NLLLossImpl::NLLLossImpl(NLLLossOptions options_)
362      : options(std::move(options_)) {
363    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
364    reset();
365  }
366  
reset()367  void NLLLossImpl::reset() {
368    weight = register_buffer("weight", options.weight());
369  }
370  
pretty_print(std::ostream & stream) const371  void NLLLossImpl::pretty_print(std::ostream& stream) const {
372    stream << "torch::nn::NLLLoss()";
373  }
374  
forward(const Tensor & input,const Tensor & target)375  Tensor NLLLossImpl::forward(const Tensor& input, const Tensor& target) {
376    return F::detail::nll_loss(
377        input, target, weight, options.ignore_index(), options.reduction());
378  }
379  
380  // ============================================================================
381  
CrossEntropyLossImpl(CrossEntropyLossOptions options_)382  CrossEntropyLossImpl::CrossEntropyLossImpl(CrossEntropyLossOptions options_)
383      : options(std::move(options_)) {
384    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
385    reset();
386  }
387  
reset()388  void CrossEntropyLossImpl::reset() {
389    weight = register_buffer("weight", options.weight());
390  }
391  
pretty_print(std::ostream & stream) const392  void CrossEntropyLossImpl::pretty_print(std::ostream& stream) const {
393    stream << "torch::nn::CrossEntropyLoss()";
394  }
395  
forward(const Tensor & input,const Tensor & target)396  Tensor CrossEntropyLossImpl::forward(
397      const Tensor& input,
398      const Tensor& target) {
399    return F::detail::cross_entropy(
400        input,
401        target,
402        weight,
403        options.ignore_index(),
404        options.reduction(),
405        options.label_smoothing());
406  }
407  
408  // ============================================================================
409  
BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_)410  BCEWithLogitsLossImpl::BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_)
411      : options(std::move(options_)) {
412    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
413    reset();
414  }
415  
reset()416  void BCEWithLogitsLossImpl::reset() {
417    weight = register_buffer("weight", options.weight());
418    pos_weight = register_buffer("pos_weight", options.pos_weight());
419  }
420  
pretty_print(std::ostream & stream) const421  void BCEWithLogitsLossImpl::pretty_print(std::ostream& stream) const {
422    stream << "torch::nn::BCEWithLogitsLoss()";
423  }
424  
forward(const Tensor & input,const Tensor & target)425  Tensor BCEWithLogitsLossImpl::forward(
426      const Tensor& input,
427      const Tensor& target) {
428    return F::detail::binary_cross_entropy_with_logits(
429        input,
430        target,
431        options.weight(),
432        options.reduction(),
433        options.pos_weight());
434  }
435  
436  } // namespace nn
437  } // namespace torch
438