xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/RenormKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/Normalization.h>
3 #include <ATen/TensorIterator.h>
4 #include <ATen/native/cpu/Loops.h>
5 
6 #include <ATen/cpu/vec/vec.h>
7 
8 #include <ATen/Dispatch.h>
9 
10 namespace at::native {
11 namespace {
12 
renorm_scale_factor_impl(TensorIteratorBase & iter,double maxnorm)13 void renorm_scale_factor_impl(TensorIteratorBase& iter, double maxnorm) {
14   AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "renorm_scale_factor_cpu", [&] {
15     using vec_t = at::vec::Vectorized<scalar_t>;
16     const auto maxnorm_s = static_cast<scalar_t>(maxnorm);
17     const auto maxnorm_v = vec_t(maxnorm_s);
18     const auto eps_v = vec_t(static_cast<scalar_t>(1e-7));
19     const auto one_v = vec_t(1.0);
20     cpu_kernel_vec(
21       iter,
22       [maxnorm_s](scalar_t norm) -> scalar_t {
23         const auto eps = static_cast<scalar_t>(1e-7);
24         return (norm > maxnorm_s) ?
25             maxnorm_s / (norm + eps) : static_cast<scalar_t>(1.0);
26       },
27       [maxnorm_v, eps_v, one_v](vec_t norm) -> vec_t {
28         auto fct = maxnorm_v / (norm + eps_v);
29         return vec_t::blendv(one_v, fct, norm > maxnorm_v);
30       });
31   });
32 }
33 
34 }  // namespace (anonymous)
35 
// Associates renorm_scale_factor_impl (defined in the anonymous namespace
// above) with renorm_scale_factor_stub through the REGISTER_DISPATCH macro.
// NOTE(review): the stub is not defined in this file; presumably it is
// declared in ATen/native/Normalization.h -- confirm.
36 REGISTER_DISPATCH(renorm_scale_factor_stub, &renorm_scale_factor_impl);
37 
38 }  // namespace at::native
39