1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/Normalization.h>
3 #include <ATen/TensorIterator.h>
4 #include <ATen/native/cpu/Loops.h>
5
6 #include <ATen/cpu/vec/vec.h>
7
8 #include <ATen/Dispatch.h>
9
10 namespace at::native {
11 namespace {
12
renorm_scale_factor_impl(TensorIteratorBase & iter,double maxnorm)13 void renorm_scale_factor_impl(TensorIteratorBase& iter, double maxnorm) {
14 AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "renorm_scale_factor_cpu", [&] {
15 using vec_t = at::vec::Vectorized<scalar_t>;
16 const auto maxnorm_s = static_cast<scalar_t>(maxnorm);
17 const auto maxnorm_v = vec_t(maxnorm_s);
18 const auto eps_v = vec_t(static_cast<scalar_t>(1e-7));
19 const auto one_v = vec_t(1.0);
20 cpu_kernel_vec(
21 iter,
22 [maxnorm_s](scalar_t norm) -> scalar_t {
23 const auto eps = static_cast<scalar_t>(1e-7);
24 return (norm > maxnorm_s) ?
25 maxnorm_s / (norm + eps) : static_cast<scalar_t>(1.0);
26 },
27 [maxnorm_v, eps_v, one_v](vec_t norm) -> vec_t {
28 auto fct = maxnorm_v / (norm + eps_v);
29 return vec_t::blendv(one_v, fct, norm > maxnorm_v);
30 });
31 });
32 }
33
34 } // namespace (anonymous)
35
36 REGISTER_DISPATCH(renorm_scale_factor_stub, &renorm_scale_factor_impl);
37
38 } // namespace at::native
39