#pragma once

#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <optional>
#include <map>
#include <tuple>
#include <vector>
#include <omp.h>

// WARNING: be extra careful when including more ATen/c10 header files here!
// Because AOTInductor generated code will copy-paste this cpp_prefix.h for
// the CPU backend, we have to make sure the used headers are implemented
// in a header-only way, i.e. all the function and class definitions are
// in .h files instead of .cpp files, to avoid ABI backward-compatibility breakage.

#include <ATen/NumericUtils.h>
#include <ATen/core/PhiloxRNGEngine.h>

#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/generic_math.h>
#include <c10/util/Half.h>
#include <c10/util/irange.h>
#include <c10/util/TypeCast.h>

#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
#define INDUCTOR_USE_VECTOR_TYPES() 1
#else
#define INDUCTOR_USE_VECTOR_TYPES() 0
#endif

#if INDUCTOR_USE_VECTOR_TYPES()
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#else
// For calc_erfinv
#include <ATen/native/Math.h>
#endif

typedef at::Half half;
typedef at::BFloat16 bfloat16;

typedef at::Float8_e4m3fn float8_e4m3fn;
typedef at::Float8_e5m2 float8_e5m2;

template <typename T>
struct Welford {
  T mean = T(0);
  T m2 = T(0);
  // Use weight for tail cases since the index of each element in the vec may
  // be different. A single index cannot express a masked Welford reduction.
  T weight = T(0);
  uint64_t index = 0;
};

template <typename T>
struct IsVecType: std::false_type {};

#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T>
struct IsVecType<at::vec::Vectorized<T>>: std::true_type {};
#endif

template <typename T>
struct WeightRecp {
  using scalar_t = typename T::value_type;
  std::vector<scalar_t> weight_recps;
  WeightRecp(uint64_t N) {
    weight_recps.reserve(N);
    for (const auto i : c10::irange(N)) {
      weight_recps.push_back(
          scalar_t(static_cast<double>(1) / static_cast<double>(i + 1)));
    }
  }
};

template <typename T>
Welford<T> welford_combine(const Welford<T>& a, const Welford<T>& b, bool use_index=false) {
  if (a.index == 0) {
    return b;
  }
  if (b.index == 0) {
    return a;
  }
  auto delta = b.mean - a.mean;
  auto a_weight = use_index ? T(a.index) : a.weight;
  auto b_weight = use_index ? T(b.index) : b.weight;
  auto new_weight = a_weight + b_weight;
  auto new_index = a.index + b.index;
  auto wb_over_w = b_weight / new_weight;
  if constexpr (IsVecType<T>::value) {
    // Guard against division by zero
    wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0));
  }
  auto result = Welford<T>{
    a.mean + delta * wb_over_w,
    a.m2 + b.m2 + delta * delta * a_weight * wb_over_w,
    new_weight,
    new_index
  };
  return result;
}

template <typename T>
Welford<T> welford_combine(const Welford<T>& acc, const T& data, const WeightRecp<T>* w=nullptr) {
  // Add a single data point
  uint64_t new_index = acc.index + 1;
  auto new_weight = acc.weight + T(1);
  auto delta = data - acc.mean;
  T new_mean;
  if constexpr (!IsVecType<T>::value) {
    new_mean = acc.mean + delta / new_weight;
  } else {
    // use new_index to fetch 1 / new_weight to avoid divisions
    new_mean = acc.mean +
      ((w == nullptr || acc.index >= w->weight_recps.size())
            ? delta / new_weight
            : delta * T(w->weight_recps[acc.index]));
  }
  auto new_delta = data - new_mean;
  auto result = Welford<T>{
    new_mean,
    acc.m2 + delta * new_delta,
    new_weight,
    new_index
  };
  return result;
}
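
// Illustrative sketch only (this helper is not part of the header's API and is
// not referenced by generated kernels): how the scalar welford_combine overload
// above can be driven to accumulate a running mean/variance over a buffer.
inline Welford<float> welford_reduce_example(const float* data, int64_t n) {
  Welford<float> acc;
  for (int64_t i = 0; i < n; ++i) {
    // Each call folds one sample into (mean, m2, weight, index).
    acc = welford_combine(acc, data[i]);
  }
  // acc.mean is the running mean; acc.m2 / acc.weight gives the variance.
  return acc;
}
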

template <typename T>
struct IndexValue {
  int64_t index;
  T value;
  IndexValue(int64_t idx, T val) : index(idx), value(val) {};
  IndexValue() {};
};

#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T>
Welford<T> welford_combine(const Welford<T>& acc, const T& data, const int64_t tail_size, const WeightRecp<T>* w=nullptr) {
  auto out = welford_combine(acc, data, w);
  return Welford<T>{
    T::set(acc.mean, out.mean, tail_size),
    T::set(acc.m2, out.m2, tail_size),
    T::set(acc.weight, out.weight, tail_size),
    out.index
  };
}

template <typename T>
T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
  auto out = at::vec::maximum(a, b);
  return T::set(a, out, tail_size);
}

template <typename T>
T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
  auto out = at::vec::minimum(a, b);
  return T::set(a, out, tail_size);
}

template <typename T>
T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
  auto out = a + b;
  return T::set(a, out, tail_size);
}

template <typename T>
T prod_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
  auto out = a * b;
  return T::set(a, out, tail_size);
}

template <typename T>
T xor_sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
  auto out = a ^ b;
  return T::set(a, out, tail_size);
}
#endif

// Refer to https://github.com/pytorch/pytorch/blob/b5b36cf0c4e1958f1ff25120f5d4beeef3288187/
// aten/src/ATen/native/SharedReduceOps.h#L419-L445
template <typename scalar_t>
inline bool greater_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
  // If (a == b), then choose the one with lower idx, else max(a, b)
  if (at::_isnan(a)) {
    if (at::_isnan(b)) {
      return idx_a < idx_b;
    }
    return true;
  }
  return (a == b) ? idx_a < idx_b : (a > b);
}

template <typename scalar_t>
inline bool less_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
  // If (a == b), then choose the one with lower idx, else min(a, b)
  if (at::_isnan(a)) {
    if (at::_isnan(b)) {
      return idx_a < idx_b;
    }
    return true;
  }
  return (a == b) ? idx_a < idx_b : (a < b);
}

template <typename T>
inline IndexValue<T>& argmin_combine(IndexValue<T>& a, T next_value, int64_t next_index){
  if(!(less_or_nan(a.value, next_value, a.index, next_index))){
    a.value = next_value;
    a.index = next_index;
  }
  return a;
}
template <typename T>
inline IndexValue<T>& argmax_combine(IndexValue<T>& a, T next_value, int64_t next_index){
  if(!(greater_or_nan(a.value, next_value, a.index, next_index))){
    a.value = next_value;
    a.index = next_index;
  }
  return a;
}
template <typename T>
inline IndexValue<T>& argmin_combine(IndexValue<T>& a, const IndexValue<T>& next){
  return argmin_combine(a, next.value, next.index);
}
template <typename T>
inline IndexValue<T>& argmax_combine(IndexValue<T>& a, const IndexValue<T>& next){
  return argmax_combine(a, next.value, next.index);
}
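
// Illustrative sketch only (hypothetical helper, not emitted by Inductor):
// a scalar argmin over a non-empty buffer using IndexValue and argmin_combine.
// NaN propagation matches SharedReduceOps: if a NaN is present, the index of
// the first NaN wins, because less_or_nan orders NaN before every other value.
inline IndexValue<float> argmin_example(const float* data, int64_t n) {
  IndexValue<float> acc(0, data[0]);
  for (int64_t i = 1; i < n; ++i) {
    argmin_combine(acc, data[i], i);
  }
  return acc;
}
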

#if INDUCTOR_USE_VECTOR_TYPES()

template <typename scalar_t>
inline at::vec::Vectorized<scalar_t> div_floor_floating_vec(
    const at::vec::Vectorized<scalar_t>& a,
    const at::vec::Vectorized<scalar_t>& b) {
  using vec_t = at::vec::Vectorized<scalar_t>;
  const auto basic_div = a / b;
  vec_t inf(std::numeric_limits<scalar_t>::infinity());
  auto mod = a.fmod(b);
  // Fixup for a case that isn't properly handled by Sleef_fmod
  auto floor = vec_t::blendv(a - mod, a, (basic_div.abs() == inf) & (a.abs() != inf));
  auto div = floor / b;
  const auto zero = vec_t(0);
  auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
  const auto one = vec_t(1);
  div = vec_t::blendv(div, div - one, mask);
  auto floordiv = div.floor();
  mask = (div - floordiv) > vec_t(0.5);
  floordiv = vec_t::blendv(floordiv, floordiv + one, mask);
  floordiv = vec_t::blendv(floordiv, zero.copysign(basic_div), div == zero);
  floordiv = vec_t::blendv(floordiv, basic_div, b == zero);
  return floordiv;
};

template <typename scalar_t, int N>
inline at::vec::VectorizedN<scalar_t, N> div_floor_floating_vec(
    const at::vec::VectorizedN<scalar_t, N>& a,
    const at::vec::VectorizedN<scalar_t, N>& b) {
  at::vec::VectorizedN<scalar_t, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
  for (int i = 0; i < N; ++i) {
    result[i] = div_floor_floating_vec(a[i], b[i]);
  }
  return result;
}

template <typename T, int NV, int NI>
struct IndexValueVec {
  at::vec::VectorizedN<T, NV> value;
  at::vec::VectorizedN<int64_t, NI> index;

  IndexValueVec(const T _value) {
    value = at::vec::VectorizedN<T, NV>(_value);
    index = at::vec::VectorizedN<int64_t, NI>(0);
  };

  IndexValueVec() {};
};

template <typename T, int NV, int NI,
          typename std::enable_if_t<at::vec::is_floating_point_v<T>, int> = 0>
at::vec::VecMask<int64_t, NI> inline get_mask_for_argmin_argmax(
  const at::vec::VecMask<T, NV>& vmask,
  const IndexValueVec<T, NV, NI>& a,
  const at::vec::VectorizedN<T, NV>& value,
  const at::vec::VectorizedN<int64_t, NI>& index
){
  /*
  vec impl for less_or_nan and greater_or_nan
  example for argmin:
    a.value = [NaN, NaN, 0, 2, 1, 0]
    value   = [NaN, 0, 0, 1, 2, NaN]
    vmask   = [false, false, false, false, true, false]
    all_nan_or_equal = [true, false, true, false, false, false]
    imask   = [a.index[0] < index[0], ..., a.index[-1] < index[-1]]
    iv_mask = blendv(vmask, imask, all_nan_or_equal)
              [a.index[0] < index[0], false, a.index[2] < index[2], false, true, false]
    a_nan_b_not: [false, false, false, false, false, true]
    mask = iv_mask | a_nan_b_not
           [a.index[0] < index[0], false, a.index[2] < index[2], false, true, true]
  */
  using v_t = at::vec::VecMask<T, NV>;
  using i_t = at::vec::VecMask<int64_t, NI>;
  i_t vmask_itype = vmask.template cast<int64_t, NI>();
  // use itype here since there is a vec impl of operator~ for itype,
  // while there may not be one for vtype
  v_t isnan_a = a.value.isnan();
  i_t isnan_a_itype = isnan_a.template cast<int64_t, NI>();
  v_t isnan_b = value.isnan();
  i_t isnan_b_type = isnan_b.template cast<int64_t, NI>();
  i_t all_nan_mask = isnan_a_itype & isnan_b_type;
  v_t equal_mask = (a.value == value);
  i_t equal_mask_itype = equal_mask.template cast<int64_t, NI>();
  i_t all_nan_or_equal = all_nan_mask | equal_mask_itype;
  i_t imask(a.index < index);
  i_t iv_mask = i_t::blendv(vmask_itype, imask, all_nan_or_equal);
  i_t isnan_a_notnan_b = isnan_a_itype & (~isnan_b_type);
  return iv_mask | isnan_a_notnan_b;
}

template <typename T, int NV, int NI,
          typename std::enable_if_t<!at::vec::is_floating_point_v<T>, int> = 0>
at::vec::VecMask<int64_t, NI> inline get_mask_for_argmin_argmax(
  const at::vec::VecMask<T, NV>& vmask,
  const IndexValueVec<T, NV, NI>& a,
  const at::vec::VectorizedN<T, NV>& value,
  const at::vec::VectorizedN<int64_t, NI>& index
){
  using v_t = at::vec::VecMask<T, NV>;
  using i_t = at::vec::VecMask<int64_t, NI>;
  i_t vmask_itype = vmask.template cast<int64_t, NI>();
  v_t equal_mask = (a.value == value);
  i_t equal_mask_itype = equal_mask.template cast<int64_t, NI>();
  i_t imask(a.index < index);
  return i_t::blendv(vmask_itype, imask, equal_mask_itype);
}

template <typename T, int NV, int NI>
inline IndexValueVec<T, NV, NI>& argmin_vec_impl(IndexValueVec<T, NV, NI>& a, at::vec::VectorizedN<T, NV> value, at::vec::VectorizedN<int64_t, NI> index, std::optional<int64_t> tail_size){
  at::vec::VecMask<T, NV> vmask(a.value < value);
  at::vec::VecMask<int64_t, NI> final_mask = get_mask_for_argmin_argmax<T, NV, NI>(vmask, a, value, index);
  if (tail_size.has_value()) {
    a.value = at::vec::VectorizedN<T, NV>::set(a.value, at::vec::minimum(a.value, value), tail_size.value());
    a.index = at::vec::VectorizedN<int64_t, NI>::set(a.index, at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask), tail_size.value());
  } else {
    a.value = at::vec::minimum(a.value, value);
    a.index = at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask);
  }
  return a;
}

template <typename T, int NV, int NI>
inline IndexValueVec<T, NV, NI>& argmax_vec_impl(IndexValueVec<T, NV, NI>& a, at::vec::VectorizedN<T, NV> value, at::vec::VectorizedN<int64_t, NI> index, std::optional<int64_t> tail_size){
  at::vec::VecMask<T, NV> vmask(a.value > value);
  at::vec::VecMask<int64_t, NI> final_mask = get_mask_for_argmin_argmax<T, NV, NI>(vmask, a, value, index);
  if (tail_size.has_value()) {
    a.value = at::vec::VectorizedN<T, NV>::set(a.value, at::vec::maximum(a.value, value), tail_size.value());
    a.index = at::vec::VectorizedN<int64_t, NI>::set(a.index, at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask), tail_size.value());
  } else {
    a.value = at::vec::maximum(a.value, value);
    a.index = at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask);
  }
  return a;
}

template <typename T, int NI, bool horizontal>
inline at::vec::VectorizedN<int64_t, NI> create_index(int64_t next_index){
  at::vec::VectorizedN<int64_t, NI> next_idx;
  if constexpr (horizontal) {
    next_idx = at::vec::VectorizedN<int64_t, NI>::arange(next_index, 1);
  } else {
    next_idx = at::vec::VectorizedN<int64_t, NI>(next_index);
  }
  return next_idx;
}

template <typename T, int NV, int NI, bool horizontal>
inline IndexValueVec<T, NV, NI>& argmin_combine_vec(IndexValueVec<T, NV, NI>& a, at::vec::VectorizedN<T, NV> next_value, int64_t next_index, std::optional<int64_t> tail_size = std::nullopt){
  auto next_idx = create_index<T, NI, horizontal>(next_index);
  return argmin_vec_impl(a, next_value, next_idx, tail_size);
}

template <typename T, int NV, int NI, bool horizontal>
inline IndexValueVec<T, NV, NI>& argmax_combine_vec(IndexValueVec<T, NV, NI>& a, at::vec::VectorizedN<T, NV> next_value, int64_t next_index, std::optional<int64_t> tail_size = std::nullopt){
  auto next_idx = create_index<T, NI, horizontal>(next_index);
  return argmax_vec_impl(a, next_value, next_idx, tail_size);
}

template <typename T, int NV, int NI>
inline IndexValue<T> argmin_vec_reduce_all(const IndexValueVec<T, NV, NI>& vec){
  constexpr int len = at::vec::VectorizedN<T, NV>::size();
  __at_align__ T tmpval[len];
  __at_align__ int64_t tmpidx[len];
  vec.value.store(tmpval);
  vec.index.store(tmpidx);
  IndexValue res = IndexValue<T>(tmpidx[0], tmpval[0]);
  for (int i = 1; i < len; i++){
    res = argmin_combine(res, tmpval[i], tmpidx[i]);
  }
  return res;
}

template <typename T, int NV, int NI>
inline IndexValue<T> argmax_vec_reduce_all(const IndexValueVec<T, NV, NI>& vec){
  constexpr int len = at::vec::VectorizedN<T, NV>::size();
  __at_align__ T tmpval[len];
  __at_align__ int64_t tmpidx[len];
  vec.value.store(tmpval);
  vec.index.store(tmpidx);
  IndexValue res = IndexValue<T>(tmpidx[0], tmpval[0]);
  for (int i = 1; i < len; i++){
    res = argmax_combine(res, tmpval[i], tmpidx[i]);
  }
  return res;
}

template <typename T, int NV, int NI>
inline IndexValueVec<T, NV, NI>& argmin_combine_vec(IndexValueVec<T, NV, NI>& vec_a, const IndexValueVec<T, NV, NI>& vec_b, std::optional<int64_t> tail_size = std::nullopt){
  return argmin_vec_impl(vec_a, vec_b.value, vec_b.index, tail_size);
}

template <typename T, int NV, int NI>
inline IndexValueVec<T, NV, NI>& argmax_combine_vec(IndexValueVec<T, NV, NI>& vec_a, const IndexValueVec<T, NV, NI>& vec_b, std::optional<int64_t> tail_size = std::nullopt){
  return argmax_vec_impl(vec_a, vec_b.value, vec_b.index, tail_size);
}

template <typename scalar_t>
inline at::vec::Vectorized<scalar_t> vec_shuffle_down(at::vec::Vectorized<scalar_t> x, size_t n) {
  using Vec = at::vec::Vectorized<scalar_t>;
  alignas(alignof(Vec)) scalar_t array[Vec::size()];
  x.store(array);
  for (size_t i = 0; i + n < Vec::size(); i += 2 * n) {
    array[i] = array[i + n];
  }
  return Vec::loadu(array);
}
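
// With an 8-lane vector and n == 1, the generic fallback above turns
// [a0, a1, ..., a7] into [a1, a1, a3, a3, a5, a5, a7, a7]; with n == 2 it moves
// lane i + 2 into lane i for every stride-4 group. welford_vec_reduce_all below
// only ever reads lane 0 of the final accumulator, so the SIMD specializations
// that follow are free to use different, but reduction-equivalent, lane mappings.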

#ifdef CPU_CAPABILITY_AVX2
inline at::vec::Vectorized<float> vec_shuffle_down(at::vec::Vectorized<float> x, size_t n) {
  using vec_t = at::vec::Vectorized<float>;
#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w)
  switch (n) {
  case 1:
    return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3)));
  case 2:
    return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2)));
  case 4:
    return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1)));
  }
  TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n);
}
#endif

#ifdef CPU_CAPABILITY_AVX512
inline at::vec::Vectorized<float> vec_shuffle_down(at::vec::Vectorized<float> x, size_t n) {
  using vec_t = at::vec::Vectorized<float>;
#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w)
  switch (n) {
  case 1:
    return vec_t(_mm512_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3)));
  case 2:
    return vec_t(_mm512_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2)));
  case 4:
    return vec_t(_mm512_permutexvar_ps(
        _mm512_set_epi32(
            12, 12, 12, 12, 12, 12, 12, 12, 4, 4, 4, 4, 4, 4, 4, 4),
        x));
  case 8:
    return vec_t(_mm512_permutexvar_ps(
        _mm512_set_epi32(8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8), x));
  }
  TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n);
}
#endif

template <typename scalar_t>
Welford<scalar_t> welford_vec_reduce_all(Welford<at::vec::Vectorized<scalar_t>> acc) {
  using Vec = at::vec::Vectorized<scalar_t>;
  Welford<scalar_t> result;
  if (acc.index == 0) {
    return result;
  }
  // If all values of acc.weight are the same as index, use index to reduce
  // to save the overhead of vec_shuffle_down for acc.weight.
  bool use_index = (acc.weight - Vec(acc.index)).zero_mask() == static_cast<int>((1 << Vec::size()) - 1);
  for (size_t n = 1; n < Vec::size(); n *= 2) {
    auto shuffled = Welford<Vec>{
      vec_shuffle_down(acc.mean, n),
      vec_shuffle_down(acc.m2, n),
      use_index ? Vec(0) : vec_shuffle_down(acc.weight, n),
      acc.index};
    acc = welford_combine(acc, shuffled, use_index);
  }

  alignas(alignof(Vec)) scalar_t array[Vec::size()];
  acc.mean.store(array);
  result.mean = array[0];

  acc.m2.store(array);
  result.m2 = array[0];

  acc.weight.store(array);
  result.weight = array[0];
  result.index = result.weight;

  return result;
}

template <typename scalar_t>
Welford<scalar_t> welford_vec_reduce_all(Welford<at::vec::VectorizedN<scalar_t, 2>> acc) {
  auto Welford0 = Welford<at::vec::Vectorized<scalar_t>>{
    acc.mean[0],
    acc.m2[0],
    acc.weight[0],
    acc.index
  };
  auto Welford1 = Welford<at::vec::Vectorized<scalar_t>>{
    acc.mean[1],
    acc.m2[1],
    acc.weight[1],
    acc.index
  };
  return welford_vec_reduce_all(welford_combine(Welford0, Welford1));
}
#endif

template <typename T, typename U> inline typename std::common_type<T, U>::type mod(T a, U b) { return a % b; }
template <> inline float mod(float a, float b) { return std::fmod(a, b); }
template <> inline double mod(double a, double b) { return std::fmod(a, b); }

template <typename scalar_t>
inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
  if (at::_isnan(a)) {
    return a;
  }
  return a > b ? a : b;
}

template <typename scalar_t>
inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
  if (at::_isnan(a)) {
    return a;
  }
  return a < b ? a : b;
}

constexpr float uint32_to_uniform_float(uint32_t value) {
  // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
  constexpr float scale = 4.6566127342e-10;
  return static_cast<float>(value & 0x7FFFFFFF) * scale;
}

float normalized_rand_cpu(uint32_t seed, uint32_t offset) {
  return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)());
}

float randn_cpu(uint32_t seed, uint32_t offset) {
  at::Philox4_32 engine(seed, 0, offset);
  return engine.randn(10);
}

int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_t high) {
  auto gen = at::Philox4_32(seed, 0, offset);
  uint64_t r0 = gen();
  uint64_t r1 = gen();
  uint64_t result = r0 | (r1 << 32);
  return static_cast<int64_t>(result % (high - low)) + low;
}

template <typename T> struct AsIntegerType { typedef T type; };
template <> struct AsIntegerType<float> { typedef uint32_t type; };
template <> struct AsIntegerType<double> { typedef uint64_t type; };
template <> struct AsIntegerType<bfloat16> { typedef uint16_t type; };

template <typename T>
typename std::enable_if_t<!std::is_reduced_floating_point_v<T>, T>
inline fetch_value(volatile T *addr) {
  return *addr;
}

template <typename T>
typename std::enable_if_t<std::is_reduced_floating_point_v<T>, T>
inline fetch_value(volatile T *addr) {
  return T(addr->x, T::from_bits());
}

template <typename T>
typename std::enable_if_t<!std::is_integral_v<T>>
atomic_add(volatile T *addr, T offset) {
  typedef typename AsIntegerType<T>::type alt_type;

  static_assert(sizeof(std::atomic<alt_type>) == sizeof(T),
                "std::atomic issue");

  alt_type expected;
  alt_type desired;

  std::atomic<alt_type> *atomic_addr = (std::atomic<alt_type> *)addr;
  do {
    T val = fetch_value(addr);
    reinterpret_cast<T *>(&expected)[0] = val;
    reinterpret_cast<T *>(&desired)[0] = val + offset;
  } while (!atomic_addr->compare_exchange_weak(expected, desired,
                                               std::memory_order_relaxed));
}

// Since C++20, float is supported by fetch_add, but the performance may not be
// better than compare_exchange_weak, which can be checked by the microbenchmark
// inductor_cpu_atomic.py.
template <typename T>
typename std::enable_if_t<std::is_integral_v<T>>
atomic_add(volatile T *addr, T offset) {
  static_assert(sizeof(std::atomic<T>) == sizeof(T),
                "std::atomic issue");
  std::atomic<T> *atomic_addr = (std::atomic<T> *)addr;
  atomic_addr->fetch_add(offset, std::memory_order_relaxed);
}
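
// Illustrative sketch only (hypothetical helper, not part of this header's
// contract): scatter-accumulating float values from parallel threads with the
// CAS-based atomic_add above, similar to the pattern generated kernels emit
// for atomic scatter-add stores.
inline void atomic_scatter_add_example(float* out, const int64_t* idx,
                                       const float* val, int64_t n) {
  #pragma omp parallel for
  for (int64_t i = 0; i < n; ++i) {
    atomic_add(out + idx[i], val[i]);
  }
}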

#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T, int NI, int NV>
void atomic_add_vec(T *addr, at::vec::VectorizedN<int64_t, NI> index, at::vec::VectorizedN<T, NV> offset) {
  constexpr int len = at::vec::VectorizedN<int64_t, NI>::size();
  static_assert(len <= at::vec::VectorizedN<T, NV>::size());
  __at_align__ std::array<T, len> tmpbuf;
  __at_align__ std::array<int64_t, len> tmpidx;
  offset.store(tmpbuf.data());
  index.store(tmpidx.data());
  for (int i = 0; i < len; i++){
    atomic_add(addr + tmpidx[i], tmpbuf[i]);
  }
}
#endif

std::tuple<std::shared_ptr<int64_t[]>, int> _get_factors(int64_t number) {
  int count = 0;
  for (int64_t i = std::sqrt(number); i > 0; --i) {
    if (number % i == 0) {
      count += 2;
    }
  }
  auto factors = std::shared_ptr<int64_t[]>(new int64_t[count]);
  int index = 0;
  for (int64_t i = std::sqrt(number); i > 0; --i) {
    if (number % i == 0) {
      factors[index++] = number / i;
      factors[index++] = i;
    }
  }
  return std::make_tuple(factors, count);
}

std::tuple<std::shared_ptr<int64_t[]>, int> get_factors(int64_t number) {
  thread_local std::map<int64_t, std::tuple<std::shared_ptr<int64_t[]>, int>> cache;
  auto it = cache.find(number);
  if (it != cache.end()) {
    return it->second;
  } else {
    auto factors = _get_factors(number);
    cache[number] = factors;
    return factors;
  }
}
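
// For reference, factors are produced in (number / i, i) pairs with i walking
// down from floor(sqrt(number)); e.g. get_factors(12) yields
// {4, 3, 6, 2, 12, 1} with count == 6.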

void _mm_get_thread_blocking(
    int num_threads,
    int max_k_slices,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t Mr,
    int64_t Nr,
    int64_t Kr,
    int64_t& Mt,
    int64_t& Nt,
    int64_t& Kt) {
  // see NOTE [Thread blocking in Cpp GEMM] for heuristics
  Mt = Nt = Kt = 0;

  auto get_blocking = [](int64_t m_factor,
                         int64_t n_factor,
                         int64_t k_factor,
                         int64_t m_blocks,
                         int64_t n_blocks,
                         int64_t k_blocks) {
    int64_t thread_block_k = (k_blocks + k_factor - 1) / k_factor;
    int64_t thread_block_n = (n_blocks + n_factor - 1) / n_factor;
    int64_t thread_block_m = (m_blocks + m_factor - 1) / m_factor;
    return std::make_tuple(thread_block_m, thread_block_n, thread_block_k);
  };

  auto is_better_blocking = [=](int64_t Mt_,
                                int64_t Nt_,
                                int64_t Kt_,
                                int64_t Mt,
                                int64_t Nt,
                                int64_t Kt) {
    return Mt == 0 || Kt_ < Kt || Mt_ * Mr + Nt_ * Nr < Mt * Mr + Nt * Nr;
  };

  int64_t m_blocks = (M + Mr - 1) / Mr;
  int64_t n_blocks = (N + Nr - 1) / Nr;
  int64_t k_blocks = (K + Kr - 1) / Kr;

  auto [factors, count] = get_factors(num_threads);
  assert(count > 0);

  for (int i = 0; i < count; ++i) {
    int64_t n_factor = factors[i];
    int64_t m_factor = num_threads / n_factor;
    if (n_blocks >= n_factor && m_blocks >= m_factor) {
      auto [Mt_, Nt_, Kt_] = get_blocking(
          m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks);
      if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
        std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
      }
    }
  }

  if (Mt != 0) {
    return;
  }

  for (int i = 0; i < count; ++i) {
    int64_t k_factor = factors[i];
    if (k_blocks >= k_factor && (max_k_slices == 0 || k_factor <= max_k_slices)) {
      auto [mxn_factors, mxn_count] = get_factors(num_threads / k_factor);
      for (int j = 0; j < mxn_count; ++j) {
        int64_t n_factor = mxn_factors[j];
        int64_t m_factor = num_threads / (k_factor * n_factor);
        if (n_blocks >= n_factor && m_blocks >= m_factor) {
          auto [Mt_, Nt_, Kt_] = get_blocking(
              m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks);
          if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
            std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
          }
        }
      }
    }
  }

  if (Mt != 0) {
    return;
  }

  for (int i = 0; i < count; ++i) {
    int64_t n_factor = factors[i];
    int64_t m_factor = num_threads / n_factor;
    if (n_blocks >= n_factor || m_blocks >= m_factor) {
      auto [Mt_, Nt_, Kt_] = get_blocking(
          m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks);
      if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
        std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
      }
    }
  }

  assert(Mt != 0);
}

void mm_get_thread_blocking(
    int num_threads,
    int max_k_slices,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t Mr,
    int64_t Nr,
    int64_t Kr,
    int64_t& Mt,
    int64_t& Nt,
    int64_t& Kt) {
  thread_local std::map<
      std::tuple<int, int, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>,
      std::tuple<int64_t, int64_t, int64_t>> cache;
  auto key = std::make_tuple(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr);
  auto it = cache.find(key);
  if (it != cache.end()) {
    std::tie(Mt, Nt, Kt) = it->second;
    return;
  } else {
    _mm_get_thread_blocking(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr, Mt, Nt, Kt);
    cache[key] = std::make_tuple(Mt, Nt, Kt);
  }
}

template<typename X_t, typename W_t>
void _mm_get_cache_blocking(
    int num_threads,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t Mr,
    int64_t Nr,
    int64_t Kr,
    int64_t Mt_blocks,
    int64_t Nt_blocks,
    int64_t Kt_blocks,
    int64_t& Mc_blocks,
    int64_t& Nc_blocks,
    int64_t& Kc_blocks,
    uint32_t L1_cache_size,
    uint32_t L2_cache_size) {
  // See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking algorithm.
  // TODO(jgong5): cache the cache blocking results
  // TODO: tune the factor here
  float L1_limit_factor = 0.8;
  float L2_limit_factor = 0.5;

  auto L1 = L1_cache_size * L1_limit_factor;
  auto L2 = L2_cache_size * L2_limit_factor;

  constexpr size_t num_byte_A = sizeof(X_t);
  constexpr size_t num_byte_B = sizeof(W_t);

  int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B;
  Kc_blocks = Kt_blocks;
  if (size_cache_B > L1) {
    Kc_blocks = (int64_t)std::floor(L1 / (Kr * Nr * num_byte_B));
  }

  float min_Mc_ratio = 2;
  int64_t min_Mc_blocks = std::ceil(min_Mc_ratio * Mr / Nr);
  auto Kt_bytes = Kt_blocks * Kr * num_byte_A;
  if (min_Mc_blocks * Mr * Kt_bytes < L2) {
    Mc_blocks = std::min(Mt_blocks, (int64_t)std::floor(L2 / (Mr * Kt_bytes)));
    Nc_blocks = 1;
  } else {
    Mc_blocks = Mt_blocks;
    Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
    auto Nc_bytes = Nc_blocks * Nr * 4;
    auto Kc_bytes = Kc_blocks * Kr * num_byte_A;
    if (Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2) {
      auto M_max = (std::sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8;
      if (M_max < Mc_blocks * Mr) {
        Mc_blocks = (int64_t)std::floor(M_max / Mr);
        Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
      }
    }
  }
}

template<typename X_t, typename W_t>
void mm_get_cache_blocking(
    int num_threads,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t Mr,
    int64_t Nr,
    int64_t Kr,
    int64_t Mt_blocks,
    int64_t Nt_blocks,
    int64_t Kt_blocks,
    int64_t& Mc_blocks,
    int64_t& Nc_blocks,
    int64_t& Kc_blocks,
    uint32_t L1_cache_size,
    uint32_t L2_cache_size) {
  thread_local std::map<
      std::tuple<int, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>,
      std::tuple<int64_t, int64_t, int64_t>> cache;
  auto key = std::make_tuple(num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, L1_cache_size, L2_cache_size);
  auto it = cache.find(key);
  if (it != cache.end()) {
    std::tie(Mc_blocks, Nc_blocks, Kc_blocks) = it->second;
    return;
  } else {
    _mm_get_cache_blocking<X_t, W_t>(
        num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, Mc_blocks, Nc_blocks, Kc_blocks, L1_cache_size, L2_cache_size);
    cache[key] = std::make_tuple(Mc_blocks, Nc_blocks, Kc_blocks);
  }
}

inline void mm_get_thread_blocks(
    int thread_id,
    int64_t M_blocks,
    int64_t N_blocks,
    int64_t K_blocks,
    int64_t Mt_blocks,
    int64_t Nt_blocks,
    int64_t Kt_blocks,
    int64_t& m_block_start,
    int64_t& m_block_end,
    int64_t& n_block_start,
    int64_t& n_block_end,
    int64_t& k_block_start,
    int64_t& k_block_end) {
  int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks;
  k_block_start = (thread_id % num_Kt) * Kt_blocks;
  k_block_end = std::min(k_block_start + Kt_blocks, K_blocks);
  thread_id /= num_Kt;
  int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks;
  n_block_start = (thread_id % num_Nt) * Nt_blocks;
  n_block_end = std::min(n_block_start + Nt_blocks, N_blocks);
  thread_id /= num_Nt;
  m_block_start = std::min(thread_id * Mt_blocks, M_blocks);
  m_block_end = std::min(m_block_start + Mt_blocks, M_blocks);
}
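
// Illustrative sketch only (hypothetical driver, not part of this header): how
// a GEMM kernel can combine the helpers above, mapping each OpenMP thread id to
// its (m, n, k) block range. Mr/Nr/Kr and the inner compute_tile callback are
// placeholders.
template <typename Fn>
inline void gemm_thread_loop_example(
    int num_threads, int64_t M, int64_t N, int64_t K,
    int64_t Mr, int64_t Nr, int64_t Kr, Fn compute_tile) {
  int64_t Mt, Nt, Kt;
  // max_k_slices == 0 places no limit on K-dimension slicing.
  mm_get_thread_blocking(num_threads, 0, M, N, K, Mr, Nr, Kr, Mt, Nt, Kt);
  int64_t M_blocks = (M + Mr - 1) / Mr;
  int64_t N_blocks = (N + Nr - 1) / Nr;
  int64_t K_blocks = (K + Kr - 1) / Kr;
  #pragma omp parallel num_threads(num_threads)
  {
    int tid = omp_get_thread_num();
    int64_t m0, m1, n0, n1, k0, k1;
    mm_get_thread_blocks(tid, M_blocks, N_blocks, K_blocks, Mt, Nt, Kt,
                         m0, m1, n0, n1, k0, k1);
    // Each thread owns the half-open block ranges [m0, m1) x [n0, n1) x [k0, k1).
    compute_tile(m0, m1, n0, n1, k0, k1);
  }
}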

struct amx_tilecfg {
  uint8_t palette_id;
  uint8_t start_row;
  uint8_t reserved_0[14];
  uint16_t colsb[16];
  uint8_t rows[16];
};

class AMXState {
 private:
  amx_tilecfg tilecfg_;
  uint8_t rows_;
  uint16_t colsb_;
  uint8_t num_tile_rows_;
  uint8_t num_tile_columns_;

 public:
  AMXState() : rows_(0), colsb_(0), num_tile_rows_(0), num_tile_columns_(0) {
    memset(&tilecfg_, 0, sizeof(tilecfg_));
  }

  inline void configure(
      uint8_t rows,
      uint16_t colsb,
      uint8_t num_tile_rows,
      uint8_t num_tile_columns,
      void (*loadconfig)(const amx_tilecfg&)) {
    if (tilecfg_.palette_id == 1 && rows_ == rows && colsb_ == colsb &&
        num_tile_rows_ == num_tile_rows &&
        num_tile_columns_ == num_tile_columns) {
      return;
    }
    tilecfg_.palette_id = 1;
    rows_ = rows;
    colsb_ = colsb;
    num_tile_rows_ = num_tile_rows;
    num_tile_columns_ = num_tile_columns;
    const auto num_c_tiles = num_tile_rows * num_tile_columns;
    // For C
    for (int i = 0; i < num_c_tiles; i++) {
      tilecfg_.rows[i] = rows;
      tilecfg_.colsb[i] = 64;
    }
    // For A
    for (int i = 0; i < num_tile_rows; i++) {
      tilecfg_.rows[i + num_c_tiles] = rows;
      tilecfg_.colsb[i + num_c_tiles] = colsb;
    }
    // For B
    for (int i = 0; i < num_tile_columns; i++) {
      tilecfg_.rows[i + num_c_tiles + num_tile_rows] = colsb / 4;
      tilecfg_.colsb[i + num_c_tiles + num_tile_rows] = 64;
    }
    loadconfig(tilecfg_);
  }

  inline void release(void (*tile_release)()) {
    tilecfg_.palette_id = 0;
    tile_release();
  }
};