1 #include <torch/nn/modules/loss.h> 2 3 namespace F = torch::nn::functional; 4 5 namespace torch { 6 namespace nn { 7 L1LossImpl(L1LossOptions options_)8 L1LossImpl::L1LossImpl(L1LossOptions options_) : options(std::move(options_)) {} 9 reset()10 void L1LossImpl::reset() {} 11 pretty_print(std::ostream & stream) const12 void L1LossImpl::pretty_print(std::ostream& stream) const { 13 stream << "torch::nn::L1Loss()"; 14 } 15 forward(const Tensor & input,const Tensor & target)16 Tensor L1LossImpl::forward(const Tensor& input, const Tensor& target) { 17 return F::detail::l1_loss(input, target, options.reduction()); 18 } 19 20 // ============================================================================ 21 KLDivLossImpl(KLDivLossOptions options_)22 KLDivLossImpl::KLDivLossImpl(KLDivLossOptions options_) 23 : options(std::move(options_)) {} 24 reset()25 void KLDivLossImpl::reset() {} 26 pretty_print(std::ostream & stream) const27 void KLDivLossImpl::pretty_print(std::ostream& stream) const { 28 stream << "torch::nn::KLDivLoss()"; 29 } 30 forward(const Tensor & input,const Tensor & target)31 Tensor KLDivLossImpl::forward(const Tensor& input, const Tensor& target) { 32 return F::detail::kl_div( 33 input, target, options.reduction(), options.log_target()); 34 } 35 36 // ============================================================================ 37 MSELossImpl(MSELossOptions options_)38 MSELossImpl::MSELossImpl(MSELossOptions options_) 39 : options(std::move(options_)) {} 40 reset()41 void MSELossImpl::reset() {} 42 pretty_print(std::ostream & stream) const43 void MSELossImpl::pretty_print(std::ostream& stream) const { 44 stream << "torch::nn::MSELoss()"; 45 } 46 forward(const Tensor & input,const Tensor & target)47 Tensor MSELossImpl::forward(const Tensor& input, const Tensor& target) { 48 return F::detail::mse_loss(input, target, options.reduction()); 49 } 50 51 // ============================================================================ 52 BCELossImpl(BCELossOptions options_)53 BCELossImpl::BCELossImpl(BCELossOptions options_) 54 : options(std::move(options_)) { 55 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 56 reset(); 57 } 58 reset()59 void BCELossImpl::reset() { 60 register_buffer("weight", options.weight()); 61 } 62 pretty_print(std::ostream & stream) const63 void BCELossImpl::pretty_print(std::ostream& stream) const { 64 stream << "torch::nn::BCELoss()"; 65 } 66 forward(const Tensor & input,const Tensor & target)67 Tensor BCELossImpl::forward(const Tensor& input, const Tensor& target) { 68 return F::detail::binary_cross_entropy( 69 input, target, options.weight(), options.reduction()); 70 } 71 72 // ============================================================================ 73 HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_)74 HingeEmbeddingLossImpl::HingeEmbeddingLossImpl( 75 HingeEmbeddingLossOptions options_) 76 : options(std::move(options_)) {} 77 reset()78 void HingeEmbeddingLossImpl::reset() {} 79 pretty_print(std::ostream & stream) const80 void HingeEmbeddingLossImpl::pretty_print(std::ostream& stream) const { 81 stream << "torch::nn::HingeEmbeddingLoss(margin=" << options.margin() << ")"; 82 } 83 forward(const Tensor & input,const Tensor & target)84 Tensor HingeEmbeddingLossImpl::forward( 85 const Tensor& input, 86 const Tensor& target) { 87 return F::detail::hinge_embedding_loss( 88 input, target, options.margin(), options.reduction()); 89 } 90 91 // ============================================================================ 92 MultiMarginLossImpl(MultiMarginLossOptions options_)93 MultiMarginLossImpl::MultiMarginLossImpl(MultiMarginLossOptions options_) 94 : options(std::move(options_)) { 95 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 96 reset(); 97 } 98 reset()99 void MultiMarginLossImpl::reset() { 100 TORCH_CHECK( 101 (options.p() == 1) || (options.p() == 2), 102 "only p == 1 and p == 2 supported"); 103 TORCH_CHECK(!options.weight().defined() || options.weight().dim() == 1); 104 105 register_buffer("weight", options.weight()); 106 } 107 pretty_print(std::ostream & stream) const108 void MultiMarginLossImpl::pretty_print(std::ostream& stream) const { 109 stream << "torch::nn::MultiMarginLoss(p=" << options.p() 110 << ", margin=" << options.margin() << ", weight=" << options.weight() 111 << ", reduction=" << enumtype::get_enum_name(options.reduction()) 112 << ")"; 113 } 114 forward(const Tensor & input,const Tensor & target)115 Tensor MultiMarginLossImpl::forward(const Tensor& input, const Tensor& target) { 116 return F::detail::multi_margin_loss( 117 input, 118 target, 119 options.p(), 120 options.margin(), 121 options.weight(), 122 options.reduction()); 123 } 124 125 // ============================================================================ 126 CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_)127 CosineEmbeddingLossImpl::CosineEmbeddingLossImpl( 128 CosineEmbeddingLossOptions options_) 129 : options(std::move(options_)) {} 130 reset()131 void CosineEmbeddingLossImpl::reset() {} 132 pretty_print(std::ostream & stream) const133 void CosineEmbeddingLossImpl::pretty_print(std::ostream& stream) const { 134 stream << "torch::nn::CosineEmbeddingLoss(margin=" << options.margin() << ")"; 135 } 136 forward(const Tensor & input1,const Tensor & input2,const Tensor & target)137 Tensor CosineEmbeddingLossImpl::forward( 138 const Tensor& input1, 139 const Tensor& input2, 140 const Tensor& target) { 141 return F::detail::cosine_embedding_loss( 142 input1, input2, target, options.margin(), options.reduction()); 143 } 144 // ============================================================================ 145 MultiLabelSoftMarginLossImpl(torch::nn::MultiLabelSoftMarginLossOptions options_)146 MultiLabelSoftMarginLossImpl::MultiLabelSoftMarginLossImpl( 147 torch::nn::MultiLabelSoftMarginLossOptions options_) 148 : options(std::move(options_)) { 149 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 150 reset(); 151 } 152 pretty_print(std::ostream & stream) const153 void MultiLabelSoftMarginLossImpl::pretty_print(std::ostream& stream) const { 154 stream << "torch::nn::MultiLabelSoftMarginLoss()"; 155 } 156 reset()157 void MultiLabelSoftMarginLossImpl::reset() { 158 register_buffer("weight", options.weight()); 159 } 160 forward(const Tensor & input,const Tensor & target)161 Tensor MultiLabelSoftMarginLossImpl::forward( 162 const Tensor& input, 163 const Tensor& target) { 164 return F::detail::multilabel_soft_margin_loss( 165 input, target, options.weight(), options.reduction()); 166 } 167 168 // ============================================================================ 169 TripletMarginLossImpl(TripletMarginLossOptions options_)170 TripletMarginLossImpl::TripletMarginLossImpl(TripletMarginLossOptions options_) 171 : options(std::move(options_)) {} 172 reset()173 void TripletMarginLossImpl::reset() {} 174 pretty_print(std::ostream & stream) const175 void TripletMarginLossImpl::pretty_print(std::ostream& stream) const { 176 stream << "torch::nn::TripletMarginLoss(margin=" << options.margin() 177 << ", p=" << options.p() << ", eps=" << options.eps() << std::boolalpha 178 << ", swap=" << options.swap() << ")"; 179 } 180 forward(const Tensor & anchor,const Tensor & positive,const Tensor & negative)181 Tensor TripletMarginLossImpl::forward( 182 const Tensor& anchor, 183 const Tensor& positive, 184 const Tensor& negative) { 185 return F::detail::triplet_margin_loss( 186 anchor, 187 positive, 188 negative, 189 options.margin(), 190 options.p(), 191 options.eps(), 192 options.swap(), 193 options.reduction()); 194 } 195 196 // ============================================================================ 197 TripletMarginWithDistanceLossImpl(TripletMarginWithDistanceLossOptions options_)198 TripletMarginWithDistanceLossImpl::TripletMarginWithDistanceLossImpl( 199 TripletMarginWithDistanceLossOptions options_) 200 : options(std::move(options_)) {} 201 reset()202 void TripletMarginWithDistanceLossImpl::reset() {} 203 pretty_print(std::ostream & stream) const204 void TripletMarginWithDistanceLossImpl::pretty_print( 205 std::ostream& stream) const { 206 stream << "torch::nn::TripletMarginWithDistanceLoss(margin=" 207 << options.margin() << std::boolalpha << ", swap=" << options.swap() 208 << ")"; 209 } 210 forward(const Tensor & anchor,const Tensor & positive,const Tensor & negative)211 Tensor TripletMarginWithDistanceLossImpl::forward( 212 const Tensor& anchor, 213 const Tensor& positive, 214 const Tensor& negative) { 215 return F::detail::triplet_margin_with_distance_loss( 216 anchor, 217 positive, 218 negative, 219 options.distance_function(), 220 options.margin(), 221 options.swap(), 222 options.reduction()); 223 } 224 225 // ============================================================================ 226 MultiLabelMarginLossImpl(torch::nn::MultiLabelMarginLossOptions options_)227 MultiLabelMarginLossImpl::MultiLabelMarginLossImpl( 228 torch::nn::MultiLabelMarginLossOptions options_) 229 : options(std::move(options_)) {} 230 reset()231 void MultiLabelMarginLossImpl::reset() {} 232 pretty_print(std::ostream & stream) const233 void MultiLabelMarginLossImpl::pretty_print(std::ostream& stream) const { 234 stream << "torch::nn::MultiLabelMarginLoss()"; 235 } 236 forward(const Tensor & input,const Tensor & target)237 Tensor MultiLabelMarginLossImpl::forward( 238 const Tensor& input, 239 const Tensor& target) { 240 return F::detail::multilabel_margin_loss(input, target, options.reduction()); 241 } 242 243 // ============================================================================ 244 SoftMarginLossImpl(torch::nn::SoftMarginLossOptions options_)245 SoftMarginLossImpl::SoftMarginLossImpl( 246 torch::nn::SoftMarginLossOptions options_) 247 : options(std::move(options_)) {} 248 reset()249 void SoftMarginLossImpl::reset() {} 250 pretty_print(std::ostream & stream) const251 void SoftMarginLossImpl::pretty_print(std::ostream& stream) const { 252 stream << "torch::nn::SoftMarginLoss()"; 253 } 254 forward(const Tensor & input,const Tensor & target)255 Tensor SoftMarginLossImpl::forward(const Tensor& input, const Tensor& target) { 256 return F::detail::soft_margin_loss(input, target, options.reduction()); 257 } 258 259 // ============================================================================ 260 SmoothL1LossImpl(torch::nn::SmoothL1LossOptions options_)261 SmoothL1LossImpl::SmoothL1LossImpl(torch::nn::SmoothL1LossOptions options_) 262 : options(std::move(options_)) {} 263 reset()264 void SmoothL1LossImpl::reset() {} 265 pretty_print(std::ostream & stream) const266 void SmoothL1LossImpl::pretty_print(std::ostream& stream) const { 267 stream << "torch::nn::SmoothL1Loss"; 268 } 269 forward(const Tensor & input,const Tensor & target)270 Tensor SmoothL1LossImpl::forward(const Tensor& input, const Tensor& target) { 271 return F::detail::smooth_l1_loss( 272 input, target, options.reduction(), options.beta()); 273 } 274 275 // ============================================================================ 276 HuberLossImpl(torch::nn::HuberLossOptions options_)277 HuberLossImpl::HuberLossImpl(torch::nn::HuberLossOptions options_) 278 : options(std::move(options_)) {} 279 reset()280 void HuberLossImpl::reset() {} 281 pretty_print(std::ostream & stream) const282 void HuberLossImpl::pretty_print(std::ostream& stream) const { 283 stream << "torch::nn::HuberLoss"; 284 } 285 forward(const Tensor & input,const Tensor & target)286 Tensor HuberLossImpl::forward(const Tensor& input, const Tensor& target) { 287 return F::detail::huber_loss( 288 input, target, options.reduction(), options.delta()); 289 } 290 291 // ============================================================================ 292 CTCLossImpl(CTCLossOptions options_)293 CTCLossImpl::CTCLossImpl(CTCLossOptions options_) 294 : options(std::move(options_)) {} 295 reset()296 void CTCLossImpl::reset() {} 297 pretty_print(std::ostream & stream) const298 void CTCLossImpl::pretty_print(std::ostream& stream) const { 299 stream << "torch::nn::CTCLoss()"; 300 } 301 forward(const Tensor & log_probs,const Tensor & targets,const Tensor & input_lengths,const Tensor & target_lengths)302 Tensor CTCLossImpl::forward( 303 const Tensor& log_probs, 304 const Tensor& targets, 305 const Tensor& input_lengths, 306 const Tensor& target_lengths) { 307 return F::detail::ctc_loss( 308 log_probs, 309 targets, 310 input_lengths, 311 target_lengths, 312 options.blank(), 313 options.reduction(), 314 options.zero_infinity()); 315 } 316 317 // ============================================================================ 318 PoissonNLLLossImpl(PoissonNLLLossOptions options_)319 PoissonNLLLossImpl::PoissonNLLLossImpl(PoissonNLLLossOptions options_) 320 : options(std::move(options_)) {} 321 reset()322 void PoissonNLLLossImpl::reset() {} 323 pretty_print(std::ostream & stream) const324 void PoissonNLLLossImpl::pretty_print(std::ostream& stream) const { 325 stream << "torch::nn::PoissonNLLLoss()"; 326 } 327 forward(const Tensor & log_input,const Tensor & target)328 Tensor PoissonNLLLossImpl::forward( 329 const Tensor& log_input, 330 const Tensor& target) { 331 return F::detail::poisson_nll_loss( 332 log_input, 333 target, 334 options.log_input(), 335 options.full(), 336 options.eps(), 337 options.reduction()); 338 } 339 340 // ============================================================================ 341 MarginRankingLossImpl(MarginRankingLossOptions options_)342 MarginRankingLossImpl::MarginRankingLossImpl(MarginRankingLossOptions options_) 343 : options(std::move(options_)) {} 344 reset()345 void MarginRankingLossImpl::reset() {} 346 pretty_print(std::ostream & stream) const347 void MarginRankingLossImpl::pretty_print(std::ostream& stream) const { 348 stream << "torch::nn::MarginRankingLoss()"; 349 } 350 forward(const Tensor & input1,const Tensor & input2,const Tensor & target)351 Tensor MarginRankingLossImpl::forward( 352 const Tensor& input1, 353 const Tensor& input2, 354 const Tensor& target) { 355 return F::detail::margin_ranking_loss( 356 input1, input2, target, options.margin(), options.reduction()); 357 } 358 359 // ============================================================================ 360 NLLLossImpl(NLLLossOptions options_)361 NLLLossImpl::NLLLossImpl(NLLLossOptions options_) 362 : options(std::move(options_)) { 363 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 364 reset(); 365 } 366 reset()367 void NLLLossImpl::reset() { 368 weight = register_buffer("weight", options.weight()); 369 } 370 pretty_print(std::ostream & stream) const371 void NLLLossImpl::pretty_print(std::ostream& stream) const { 372 stream << "torch::nn::NLLLoss()"; 373 } 374 forward(const Tensor & input,const Tensor & target)375 Tensor NLLLossImpl::forward(const Tensor& input, const Tensor& target) { 376 return F::detail::nll_loss( 377 input, target, weight, options.ignore_index(), options.reduction()); 378 } 379 380 // ============================================================================ 381 CrossEntropyLossImpl(CrossEntropyLossOptions options_)382 CrossEntropyLossImpl::CrossEntropyLossImpl(CrossEntropyLossOptions options_) 383 : options(std::move(options_)) { 384 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 385 reset(); 386 } 387 reset()388 void CrossEntropyLossImpl::reset() { 389 weight = register_buffer("weight", options.weight()); 390 } 391 pretty_print(std::ostream & stream) const392 void CrossEntropyLossImpl::pretty_print(std::ostream& stream) const { 393 stream << "torch::nn::CrossEntropyLoss()"; 394 } 395 forward(const Tensor & input,const Tensor & target)396 Tensor CrossEntropyLossImpl::forward( 397 const Tensor& input, 398 const Tensor& target) { 399 return F::detail::cross_entropy( 400 input, 401 target, 402 weight, 403 options.ignore_index(), 404 options.reduction(), 405 options.label_smoothing()); 406 } 407 408 // ============================================================================ 409 BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_)410 BCEWithLogitsLossImpl::BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_) 411 : options(std::move(options_)) { 412 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 413 reset(); 414 } 415 reset()416 void BCEWithLogitsLossImpl::reset() { 417 weight = register_buffer("weight", options.weight()); 418 pos_weight = register_buffer("pos_weight", options.pos_weight()); 419 } 420 pretty_print(std::ostream & stream) const421 void BCEWithLogitsLossImpl::pretty_print(std::ostream& stream) const { 422 stream << "torch::nn::BCEWithLogitsLoss()"; 423 } 424 forward(const Tensor & input,const Tensor & target)425 Tensor BCEWithLogitsLossImpl::forward( 426 const Tensor& input, 427 const Tensor& target) { 428 return F::detail::binary_cross_entropy_with_logits( 429 input, 430 target, 431 options.weight(), 432 options.reduction(), 433 options.pos_weight()); 434 } 435 436 } // namespace nn 437 } // namespace torch 438