1 #pragma once
2 
3 #include <algorithm>
4 #include <atomic>
5 #include <cmath>
6 #include <cstdlib>
7 #include <limits>
8 #include <memory>
9 #include <optional>
10 #include <map>
11 #include <omp.h>
12 
13 // WARNING: be extra careful when including more ATen/c10 header files here!
14 // Because AOTInductor-generated code will copy-paste this cpp_prefix.h for
15 // the CPU backend, we have to make sure the headers used here are implemented
16 // in a header-only way, i.e. all the function and class definitions are
17 // in .h files instead of .cpp files, to avoid ABI backward-compatibility breakage.
18 
19 #include <ATen/NumericUtils.h>
20 #include <ATen/core/PhiloxRNGEngine.h>
21 
22 #include <c10/util/Float8_e4m3fn.h>
23 #include <c10/util/Float8_e5m2.h>
24 #include <c10/util/BFloat16.h>
25 #include <c10/util/BFloat16-math.h>
26 #include <c10/util/generic_math.h>
27 #include <c10/util/Half.h>
28 #include <c10/util/TypeCast.h>
29 
30 #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
31 #define INDUCTOR_USE_VECTOR_TYPES() 1
32 #else
33 #define INDUCTOR_USE_VECTOR_TYPES() 0
34 #endif
35 
36 #if INDUCTOR_USE_VECTOR_TYPES()
37 #include <ATen/cpu/vec/functional.h>
38 #include <ATen/cpu/vec/vec.h>
39 #else
40 // For calc_erfinv
41 #include <ATen/native/Math.h>
42 #endif
43 
44 typedef at::Half half;
45 typedef at::BFloat16 bfloat16;
46 
47 typedef at::Float8_e4m3fn float8_e4m3fn;
48 typedef at::Float8_e5m2 float8_e5m2;
49 
50 template <typename T>
51 struct Welford {
52   T mean = T(0);
53   T m2 = T(0);
54   // Use weight for tail cases since the index of each element in the vec may
55   // differ. A single index cannot express a masked Welford reduction.
56   T weight = T(0);
57   uint64_t index = 0;
58 };
59 
60 
61 template <typename T>
62 struct IsVecType: std::false_type {};
63 
64 #if INDUCTOR_USE_VECTOR_TYPES()
65 template <typename T>
66 struct IsVecType<at::vec::Vectorized<T>>: std::true_type {};
67 #endif
68 
69 template <typename T>
70 struct WeightRecp {
71   using scalar_t = typename T::value_type;
72   std::vector<scalar_t> weight_recps;
73   WeightRecp(uint64_t N) {
74     weight_recps.reserve(N);
75     for (const auto i : c10::irange(N)) {
76       weight_recps.push_back(
77           scalar_t(static_cast<double>(1) / static_cast<double>(i + 1)));
78     }
79   }
80 };
81 
82 template <typename T>
83 Welford<T> welford_combine(const Welford<T>& a, const Welford<T>& b, bool use_index=false) {
84   if (a.index == 0) {
85     return b;
86   }
87   if (b.index == 0) {
88     return a;
89   }
90   auto delta = b.mean - a.mean;
91   auto a_weight = use_index ? T(a.index) : a.weight;
92   auto b_weight = use_index ? T(b.index) : b.weight;
93   auto new_weight = a_weight + b_weight;
94   auto new_index = a.index + b.index;
95   auto wb_over_w = b_weight / new_weight;
96   if constexpr (IsVecType<T>::value) {
97     // Guard against division by zero
98     wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0));
99   }
100   auto result = Welford<T>{
101     a.mean + delta * wb_over_w,
102     a.m2 + b.m2 + delta * delta * a_weight * wb_over_w,
103     new_weight,
104     new_index
105   };
106   return result;
107 }
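
// Note: this is the standard parallel (Chan et al.) Welford merge, applied per
// lane, with delta = mean_b - mean_a and combined weight w = w_a + w_b:
//   mean = mean_a + delta * w_b / w
//   m2   = m2_a + m2_b + delta^2 * w_a * w_b / w
// `use_index` substitutes the integer element counts for the weights when the
// caller knows the two coincide (no masked tail involved).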
108 
109 template <typename T>
110 Welford<T> welford_combine(const Welford<T>& acc, const T& data, const WeightRecp<T>* w=nullptr) {
111   // Add a single data point
112   uint64_t new_index = acc.index + 1;
113   auto new_weight = acc.weight + T(1);
114   auto delta = data - acc.mean;
115   T new_mean;
116   if constexpr (!IsVecType<T>::value) {
117     new_mean = acc.mean + delta / new_weight;
118   } else {
119     // use the precomputed 1 / new_weight (weight_recps[acc.index]) to avoid a division
120     new_mean = acc.mean +
121       ((w == nullptr || acc.index >= w->weight_recps.size())
122             ? delta / new_weight
123             : delta * T(w->weight_recps[acc.index]));
124   }
125   auto new_delta = data - new_mean;
126   auto result = Welford<T>{
127     new_mean,
128     acc.m2 + delta * new_delta,
129     new_weight,
130     new_index
131   };
132   return result;
133 }
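
// Example (illustrative sketch, not used by generated kernels): accumulating a
// few scalars with the helper above; the numbers are worked out by hand from
// the update formulas.
//
//   Welford<float> acc;
//   for (float x : {1.f, 2.f, 3.f, 4.f}) {
//     acc = welford_combine(acc, x);
//   }
//   // acc.mean == 2.5f, acc.m2 == 5.f (sum of squared deviations),
//   // acc.weight == 4.f, acc.index == 4; variance == acc.m2 / acc.weight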
134 
135 template <typename T>
136 struct IndexValue {
137   int64_t index;
138   T value;
139   IndexValue(int64_t idx, T val) : index(idx), value(val) {}
140   IndexValue() {}
141 };
142 
143 #if INDUCTOR_USE_VECTOR_TYPES()
144 template <typename T>
145 Welford<T> welford_combine(const Welford<T>& acc, const T& data, const int64_t tail_size, const WeightRecp<T>* w=nullptr) {
146   auto out = welford_combine(acc, data, w);
147   return Welford<T>{
148     T::set(acc.mean, out.mean, tail_size),
149     T::set(acc.m2, out.m2, tail_size),
150     T::set(acc.weight, out.weight, tail_size),
151     out.index
152   };
153 }
154 
155 template <typename T>
156 T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
157   auto out = at::vec::maximum(a, b);
158   return T::set(a, out, tail_size);
159 }
160 
161 template <typename T>
162 T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
163   auto out = at::vec::minimum(a, b);
164   return T::set(a, out, tail_size);
165 }
166 
167 template <typename T>
168 T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
169   auto out = a + b;
170   return T::set(a, out, tail_size);
171 }
172 
173 template <typename T>
174 T prod_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
175   auto out = a * b;
176   return T::set(a, out, tail_size);
177 }
178 
179 template <typename T>
180 T xor_sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
181   auto out = a ^ b;
182   return T::set(a, out, tail_size);
183 }
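
// Note (illustrative, assuming Vectorized::set takes its first `count` lanes
// from the second argument): the *_masked_reduce helpers only update the first
// `tail_size` lanes of the accumulator; the remaining lanes keep `a`. With
// 8-lane float vectors:
//
//   a = {1, 1, 1, 1, 1, 1, 1, 1}, b = {2, 2, 2, 2, 2, 2, 2, 2}
//   sum_masked_reduce(a, b, 3) == {3, 3, 3, 1, 1, 1, 1, 1}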
184 #endif
185 
186 // Refer to https://github.com/pytorch/pytorch/blob/b5b36cf0c4e1958f1ff25120f5d4beeef3288187/
187 // aten/src/ATen/native/SharedReduceOps.h#L419-L445
188 template <typename scalar_t>
189 inline bool greater_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
190   // If (a == b), then choose the one with lower idx, else max(a, b)
191   if (at::_isnan(a)) {
192     if (at::_isnan(b)) {
193       return idx_a < idx_b;
194     }
195     return true;
196   }
197   return (a == b) ? idx_a < idx_b : (a > b);
198 }
199 
200 template <typename scalar_t>
201 inline bool less_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
202   // If (a == b), then choose the one with lower idx, else min(a, b)
203   if (at::_isnan(a)) {
204     if (at::_isnan(b)) {
205       return idx_a < idx_b;
206     }
207     return true;
208   }
209   return (a == b) ? idx_a < idx_b : (a < b);
210 }
211 
212 template <typename T>
213 inline IndexValue<T>& argmin_combine(IndexValue<T>& a, T next_value, int64_t next_index){
214   if(!(less_or_nan(a.value, next_value, a.index, next_index))){
215     a.value = next_value;
216     a.index = next_index;
217   }
218   return a;
219 }
220 template <typename T>
221 inline IndexValue<T>& argmax_combine(IndexValue<T>& a, T next_value, int64_t next_index){
222   if(!(greater_or_nan(a.value, next_value, a.index, next_index))){
223     a.value = next_value;
224     a.index = next_index;
225   }
226   return a;
227 }
228 template <typename T>
229 inline IndexValue<T>& argmin_combine(IndexValue<T>& a, const IndexValue<T>& next){
230   return argmin_combine(a, next.value, next.index);
231 }
232 template <typename T>
233 inline IndexValue<T>& argmax_combine(IndexValue<T>& a, const IndexValue<T>& next){
234   return argmax_combine(a, next.value, next.index);
235 }
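
// Example (illustrative sketch): a scalar argmax over a small array with the
// combine helpers above; NaNs win, and ties keep the lower index.
//
//   float data[5] = {1.f, 3.f, 3.f, 2.f, 0.f};
//   IndexValue<float> best(0, data[0]);
//   for (int64_t i = 1; i < 5; i++) {
//     argmax_combine(best, data[i], i);
//   }
//   // best.index == 1, best.value == 3.f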
236 
237 #if INDUCTOR_USE_VECTOR_TYPES()
238 
239 template <typename scalar_t>
240 inline at::vec::Vectorized<scalar_t> div_floor_floating_vec(
241     const at::vec::Vectorized<scalar_t>& a,
242     const at::vec::Vectorized<scalar_t>& b) {
243   using vec_t = at::vec::Vectorized<scalar_t>;
244   const auto basic_div = a / b;
245   vec_t inf(std::numeric_limits<scalar_t>::infinity());
246   auto mod = a.fmod(b);
247   // Fixup for a case that isn't properly handled by Sleef_fmod
248   auto floor = vec_t::blendv(a - mod, a, (basic_div.abs() == inf) & (a.abs() != inf));
249   auto div = floor / b;
250   const auto zero = vec_t(0);
251   auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
252   const auto one = vec_t(1);
253   div = vec_t::blendv(div, div - one, mask);
254   auto floordiv = div.floor();
255   mask = (div - floordiv) > vec_t(0.5);
256   floordiv = vec_t::blendv(floordiv, floordiv + one, mask);
257   floordiv = vec_t::blendv(floordiv, zero.copysign(basic_div), div == zero);
258   floordiv = vec_t::blendv(floordiv, basic_div, b == zero);
259   return floordiv;
260 }
261 
262 template <typename scalar_t, int N>
263 inline at::vec::VectorizedN<scalar_t, N> div_floor_floating_vec(
264     const at::vec::VectorizedN<scalar_t, N>& a,
265     const at::vec::VectorizedN<scalar_t, N>& b) {
266     at::vec::VectorizedN<scalar_t, N> result;
267 #ifndef _MSC_VER
268 #pragma unroll
269 #endif
270     for (int i = 0; i < N; ++i) {
271       result[i] = div_floor_floating_vec(a[i], b[i]);
272     }
273     return result;
274 }
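
// Note (illustrative): div_floor_floating_vec implements floor division lane by
// lane, intended to match the scalar c10::div_floor_floating from
// c10/util/generic_math.h. Writing div_floor(x, y) for a single lane:
//
//   div_floor( 7.f,  2.f) ==  3.f
//   div_floor(-7.f,  2.f) == -4.f   // rounds toward -inf, not toward zero
//   div_floor( 7.f, -2.f) == -4.f
//   div_floor( 1.f,  0.f) ==  inf   // b == 0 falls back to the basic division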
275 
276 template <typename T, int NV, int NI>
277 struct IndexValueVec {
278   at::vec::VectorizedN<T, NV> value;
279   at::vec::VectorizedN<int64_t, NI> index;
280 
281   IndexValueVec(const T _value) {
282     value = at::vec::VectorizedN<T, NV>(_value);
283     index = at::vec::VectorizedN<int64_t, NI>(0);
284   }
285 
286   IndexValueVec() {}
287 };
288 
289 
290 template <typename T, int NV, int NI,
291           typename std::enable_if_t<at::vec::is_floating_point_v<T>, int> = 0>
292 at::vec::VecMask<int64_t, NI> inline get_mask_for_argmin_argmax(
293   const at::vec::VecMask<T, NV>& vmask,
294   const IndexValueVec<T, NV, NI>& a,
295   const at::vec::VectorizedN<T, NV>& value,
296   const at::vec::VectorizedN<int64_t, NI>& index
297 ){
298   /*
299   vec impl for less_or_nan and greater_or_nan
300   example for argmin:
301   a.value = [NaN, NaN, 0, 2, 1, 0]
302   value = [NaN, 0, 0, 1, 2, NaN]
303   vmask = [false, false, false, false, true, false]
304   all_nan_or_equal = [true, false, true, false, false, false]
305   imask = [a.index[0] < index[0], ..., a.index[-1] < index[-1]]
306   iv_mask = blendv (vmask, imask, all_nan_or_equal)
307           [a.index[0] < index[0], false, a.index[2] < index[2], false, true, false]
308   a_nan_b_not: [false, false, false, false, false, true]
309   mask = iv_mask | a_nan_b_not
310           [a.index[0] < index[0], false, a.index[2] < index[2], false, true, true]
311   */
312   using v_t = at::vec::VecMask<T, NV>;
313   using i_t = at::vec::VecMask<int64_t, NI>;
314   i_t vmask_itype = vmask.template cast<int64_t, NI>();
315   // use itype here since there is a vec impl of operator~ for itype,
316   // while there may not be one for vtype
317   v_t isnan_a = a.value.isnan();
318   i_t isnan_a_itype = isnan_a.template cast<int64_t, NI>();
319   v_t isnan_b = value.isnan();
320   i_t isnan_b_type = isnan_b.template cast<int64_t, NI>();
321   i_t all_nan_mask = isnan_a_itype & isnan_b_type;
322   v_t equal_mask = (a.value == value);
323   i_t equal_mask_itype = equal_mask.template cast<int64_t, NI>();
324   i_t all_nan_or_equal = all_nan_mask | equal_mask_itype;
325   i_t imask(a.index < index);
326   i_t iv_mask = i_t::blendv(vmask_itype, imask, all_nan_or_equal);
327   i_t isnan_a_notnan_b = isnan_a_itype & (~isnan_b_type);
328   return iv_mask | isnan_a_notnan_b;
329 }
330 
331 template <typename T, int NV, int NI,
332           typename std::enable_if_t<!at::vec::is_floating_point_v<T>, int> = 0>
333 at::vec::VecMask<int64_t, NI> inline get_mask_for_argmin_argmax(
334   const at::vec::VecMask<T, NV>& vmask,
335   const IndexValueVec<T, NV, NI>& a,
336   const at::vec::VectorizedN<T, NV>& value,
337   const at::vec::VectorizedN<int64_t, NI>& index
338 ){
339   using v_t = at::vec::VecMask<T, NV>;
340   using i_t = at::vec::VecMask<int64_t, NI>;
341   i_t vmask_itype = vmask.template cast<int64_t, NI>();
342   v_t equal_mask = (a.value == value);
343   i_t equal_mask_itype = equal_mask.template cast<int64_t, NI>();
344   i_t imask(a.index < index);
345   return i_t::blendv(vmask_itype, imask, equal_mask_itype);
346 }
347 
348 
349 template <typename T, int NV, int NI>
350 inline IndexValueVec<T, NV, NI>& argmin_vec_impl(IndexValueVec<T, NV, NI>& a,  at::vec::VectorizedN<T, NV> value, at::vec::VectorizedN<int64_t, NI> index, std::optional<int64_t> tail_size){
351   at::vec::VecMask<T, NV> vmask(a.value < value);
352   at::vec::VecMask<int64_t, NI> final_mask = get_mask_for_argmin_argmax<T, NV, NI>(vmask, a, value, index);
353   if (tail_size.has_value()) {
354     a.value = at::vec::VectorizedN<T, NV>::set(a.value, at::vec::minimum(a.value, value), tail_size.value());
355     a.index = at::vec::VectorizedN<int64_t, NI>::set(a.index, at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask), tail_size.value());
356   } else {
357     a.value = at::vec::minimum(a.value, value);
358     a.index = at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask);
359   }
360   return a;
361 }
362 
363 template <typename T, int NV, int NI>
364 inline IndexValueVec<T, NV, NI>& argmax_vec_impl(IndexValueVec<T, NV, NI>& a,  at::vec::VectorizedN<T, NV> value, at::vec::VectorizedN<int64_t, NI> index, std::optional<int64_t> tail_size){
365   at::vec::VecMask<T, NV> vmask(a.value > value);
366   at::vec::VecMask<int64_t, NI> final_mask = get_mask_for_argmin_argmax<T, NV, NI>(vmask, a, value, index);
367   if (tail_size.has_value()) {
368     a.value = at::vec::VectorizedN<T, NV>::set(a.value, at::vec::maximum(a.value, value), tail_size.value());
369     a.index = at::vec::VectorizedN<int64_t, NI>::set(a.index, at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask), tail_size.value());
370   } else {
371     a.value = at::vec::maximum(a.value, value);
372     a.index = at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask);
373   }
374   return a;
375 }
376 
377 template <typename T, int NI, bool horizontal>
378 inline at::vec::VectorizedN<int64_t, NI> create_index(int64_t next_index){
379   at::vec::VectorizedN<int64_t, NI> next_idx;
380   if constexpr (horizontal) {
381     next_idx = at::vec::VectorizedN<int64_t, NI>::arange(next_index, 1);
382   } else {
383     next_idx = at::vec::VectorizedN<int64_t, NI>(next_index);
384   }
385   return next_idx;
386 }
387 
388 template <typename T, int NV, int NI, bool horizontal>
389 inline IndexValueVec<T, NV, NI>& argmin_combine_vec(IndexValueVec<T, NV, NI>& a, at::vec::VectorizedN<T, NV> next_value, int64_t next_index, std::optional<int64_t> tail_size = std::nullopt){
390   auto next_idx = create_index<T, NI, horizontal>(next_index);
391   return argmin_vec_impl(a, next_value, next_idx, tail_size);
392 }
393 
394 template <typename T, int NV, int NI, bool horizontal>
395 inline IndexValueVec<T, NV, NI>& argmax_combine_vec(IndexValueVec<T, NV, NI>& a, at::vec::VectorizedN<T, NV> next_value, int64_t next_index, std::optional<int64_t> tail_size = std::nullopt){
396   auto next_idx = create_index<T, NI, horizontal>(next_index);
397   return argmax_vec_impl(a, next_value, next_idx, tail_size);
398 }
399 
400 template <typename T, int NV, int NI>
401 inline IndexValue<T> argmin_vec_reduce_all(const IndexValueVec<T, NV, NI>& vec){
402   constexpr int len = at::vec::VectorizedN<T, NV>::size();
403   __at_align__ T tmpval[len];
404   __at_align__ int64_t tmpidx[len];
405   vec.value.store(tmpval);
406   vec.index.store(tmpidx);
407   IndexValue res = IndexValue<T>(tmpidx[0], tmpval[0]);
408   for (int i = 1; i < len; i++){
409     res = argmin_combine(res, tmpval[i], tmpidx[i]);
410   }
411   return res;
412 }
413 
414 template <typename T, int NV, int NI>
415 inline IndexValue<T> argmax_vec_reduce_all(const IndexValueVec<T, NV, NI>& vec){
416   constexpr int len = at::vec::VectorizedN<T, NV>::size();
417   __at_align__ T tmpval[len];
418   __at_align__ int64_t tmpidx[len];
419   vec.value.store(tmpval);
420   vec.index.store(tmpidx);
421   IndexValue res = IndexValue<T>(tmpidx[0], tmpval[0]);
422   for (int i = 1; i < len; i++){
423     res = argmax_combine(res, tmpval[i], tmpidx[i]);
424   }
425   return res;
426 }
427 
428 template <typename T, int NV, int NI>
429 inline IndexValueVec<T, NV, NI>& argmin_combine_vec(IndexValueVec<T, NV, NI>& vec_a, const IndexValueVec<T, NV, NI>& vec_b, std::optional<int64_t> tail_size = std::nullopt){
430   return argmin_vec_impl(vec_a, vec_b.value, vec_b.index, tail_size);
431 }
432 
433 template <typename T, int NV, int NI>
434 inline IndexValueVec<T, NV, NI>& argmax_combine_vec(IndexValueVec<T, NV, NI>& vec_a, const IndexValueVec<T, NV, NI>& vec_b, std::optional<int64_t> tail_size = std::nullopt){
435   return argmax_vec_impl(vec_a, vec_b.value, vec_b.index, tail_size);
436 }
437 
438 template <typename scalar_t>
439 inline at::vec::Vectorized<scalar_t> vec_shuffle_down(at::vec::Vectorized<scalar_t> x, size_t n) {
440   using Vec = at::vec::Vectorized<scalar_t>;
441   alignas(alignof(Vec)) scalar_t array[Vec::size()];
442   x.store(array);
443   for (size_t i = 0; i + n < Vec::size(); i += 2 * n) {
444     array[i] = array[i + n];
445   }
446   return Vec::loadu(array);
447 }
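
// Example (illustrative, for the generic fallback above with 8 float lanes):
// lane i takes lane i + n at every stride of 2 * n; welford_vec_reduce_all
// below only reads lane 0 of the final result.
//
//   x                      == {0, 1, 2, 3, 4, 5, 6, 7}
//   vec_shuffle_down(x, 1) == {1, 1, 3, 3, 5, 5, 7, 7}
//   vec_shuffle_down(x, 2) == {2, 1, 2, 3, 6, 5, 6, 7}
//   vec_shuffle_down(x, 4) == {4, 1, 2, 3, 4, 5, 6, 7}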
448 
449 #ifdef CPU_CAPABILITY_AVX2
450 inline at::vec::Vectorized<float> vec_shuffle_down(at::vec::Vectorized<float> x, size_t n) {
451   using vec_t = at::vec::Vectorized<float>;
452 #define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w)
453   switch (n) {
454   case 1:
455     return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3)));
456   case 2:
457     return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2)));
458   case 4:
459     return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1)));
460   }
461   TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n);
462 }
463 #endif
464 
465 #ifdef CPU_CAPABILITY_AVX512
466 inline at::vec::Vectorized<float> vec_shuffle_down(at::vec::Vectorized<float> x, size_t n) {
467   using vec_t = at::vec::Vectorized<float>;
468 #define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w)
469   switch (n) {
470     case 1:
471       return vec_t(_mm512_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3)));
472     case 2:
473       return vec_t(_mm512_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2)));
474     case 4:
475       return vec_t(_mm512_permutexvar_ps(
476           _mm512_set_epi32(
477               12, 12, 12, 12, 12, 12, 12, 12, 4, 4, 4, 4, 4, 4, 4, 4),
478           x));
479     case 8:
480       return vec_t(_mm512_permutexvar_ps(
481           _mm512_set_epi32(8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8), x));
482   }
483   TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n);
484 }
485 #endif
486 
487 template <typename scalar_t>
488 Welford<scalar_t> welford_vec_reduce_all(Welford<at::vec::Vectorized<scalar_t>> acc) {
489   using Vec = at::vec::Vectorized<scalar_t>;
490   Welford<scalar_t> result;
491   if (acc.index == 0) {
492     return result;
493   }
494   // if all values of acc.weight are the same as acc.index,
495   // reduce via the index to save the overhead of vec_shuffle_down on acc.weight
496   bool use_index = (acc.weight - Vec(acc.index)).zero_mask() == static_cast<int>((1 << Vec::size()) - 1);
497   for (size_t n = 1; n < Vec::size(); n *= 2) {
498     auto shuffled = Welford<Vec>{
499       vec_shuffle_down(acc.mean, n),
500       vec_shuffle_down(acc.m2, n),
501       use_index ? Vec(0) : vec_shuffle_down(acc.weight, n),
502       acc.index};
503     acc = welford_combine(acc, shuffled, use_index);
504   }
505 
506   alignas(alignof(Vec)) scalar_t array[Vec::size()];
507   acc.mean.store(array);
508   result.mean = array[0];
509 
510   acc.m2.store(array);
511   result.m2 = array[0];
512 
513   acc.weight.store(array);
514   result.weight = array[0];
515   result.index = result.weight;
516 
517   return result;
518 }
519 
520 template <typename scalar_t>
521 Welford<scalar_t> welford_vec_reduce_all(Welford<at::vec::VectorizedN<scalar_t, 2>> acc) {
522   auto Welford0 = Welford<at::vec::Vectorized<scalar_t>>{
523     acc.mean[0],
524     acc.m2[0],
525     acc.weight[0],
526     acc.index
527   };
528   auto Welford1 = Welford<at::vec::Vectorized<scalar_t>>{
529     acc.mean[1],
530     acc.m2[1],
531     acc.weight[1],
532     acc.index
533   };
534   return welford_vec_reduce_all(welford_combine(Welford0, Welford1));
535 }
536 #endif
537 
538 
539 template <typename T, typename U> inline typename std::common_type<T, U>::type mod(T a, U b) { return a % b; }
540 template <> inline float mod(float a, float b) { return std::fmod(a, b); }
541 template <> inline double mod(double a, double b) { return std::fmod(a, b); }
542 
543 template <typename scalar_t>
544 inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
545   if (at::_isnan(a)) {
546     return a;
547   }
548   return a > b ? a : b;
549 }
550 
551 template <typename scalar_t>
552 inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
553   if (at::_isnan(a)) {
554     return a;
555   }
556   return a < b ? a : b;
557 }
558 
559 constexpr float uint32_to_uniform_float(uint32_t value) {
560   // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
561   constexpr float scale = 4.6566127342e-10;
562   return static_cast<float>(value & 0x7FFFFFFF) * scale;
563 }
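
// Note: the scale above is chosen slightly below 2^-31, so even the largest
// masked value (0x7FFFFFFF, which rounds to 2^31 as a float) maps strictly
// below 1.0, giving a result in [0, 1).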
564 
565 float normalized_rand_cpu(uint32_t seed, uint32_t offset) {
566   return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)());
567 }
568 
569 float randn_cpu(uint32_t seed, uint32_t offset) {
570   at::Philox4_32 engine(seed, 0, offset);
571   return engine.randn(10);
572 }
573 
574 int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_t high) {
575   auto gen = at::Philox4_32(seed, 0, offset);
576   uint64_t r0 = gen();
577   uint64_t r1 = gen();
578   uint64_t result = r0 | (r1 << 32);
579   return static_cast<int64_t>(result % (high - low)) + low;
580 }
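
// Example (illustrative sketch): these helpers are stateless; each call seeds a
// fresh Philox4_32 engine, so a given (seed, offset) pair always reproduces the
// same value.
//
//   float u = normalized_rand_cpu(/*seed=*/42, /*offset=*/0);    // uniform in [0, 1)
//   float g = randn_cpu(/*seed=*/42, /*offset=*/1);              // normal sample via Philox4_32::randn
//   int64_t r = randint64_cpu(42, 2, /*low=*/0, /*high=*/10);    // integer in [0, 10)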
581 
582 template <typename T> struct AsIntegerType { typedef T type; };
583 template <> struct AsIntegerType<float> { typedef uint32_t type; };
584 template <> struct AsIntegerType<double> { typedef uint64_t type; };
585 template <> struct AsIntegerType<bfloat16> { typedef uint16_t type; };
586 
587 template <typename T>
588 typename std::enable_if_t<!std::is_reduced_floating_point_v<T>, T>
589 inline fetch_value(volatile T *addr) {
590   return *addr;
591 }
592 
593 template <typename T>
594 typename std::enable_if_t<std::is_reduced_floating_point_v<T>, T>
595 inline fetch_value(volatile T *addr) {
596   return T(addr->x, T::from_bits());
597 }
598 
599 template <typename T>
600 typename std::enable_if_t<!std::is_integral_v<T>>
601 atomic_add(volatile T *addr, T offset) {
602   typedef typename AsIntegerType<T>::type alt_type;
603 
604   static_assert(sizeof(std::atomic<alt_type>) == sizeof(T),
605                 "std::atomic issue");
606 
607   alt_type expected;
608 
609   alt_type desired;
610 
611   std::atomic<alt_type> *atomic_addr = (std::atomic<alt_type> *)addr;
612   do {
613     T val = fetch_value(addr);
614     reinterpret_cast<T *>(&expected)[0] = val;
615     reinterpret_cast<T *>(&desired)[0] = val + offset;
616   } while (!atomic_addr->compare_exchange_weak(expected, desired,
617                                                std::memory_order_relaxed));
618 }
619 
620 // Since C++20, float is supported by fetch_add, but its performance may not be
621 // better than compare_exchange_weak, which can be checked with the
622 // inductor_cpu_atomic.py microbenchmark.
623 template <typename T>
624 typename std::enable_if_t<std::is_integral_v<T>>
625 atomic_add(volatile T *addr, T offset) {
626   static_assert(sizeof(std::atomic<T>) == sizeof(T),
627                 "std::atomic issue");
628   std::atomic<T> *atomic_addr = (std::atomic<T> *)addr;
629   atomic_addr->fetch_add(offset, std::memory_order_relaxed);
630 }
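
// Example (illustrative sketch): both overloads share the same call shape; the
// type decides between the CAS loop (floating point) and fetch_add (integers).
//
//   float fsum = 0.f;
//   int64_t isum = 0;
//   atomic_add(&fsum, 1.5f);         // compare_exchange_weak on the uint32_t bits
//   atomic_add(&isum, int64_t(2));   // std::atomic<int64_t>::fetch_add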
631 
632 #if INDUCTOR_USE_VECTOR_TYPES()
633 template <typename T, int NI, int NV>
634 void atomic_add_vec(T *addr, at::vec::VectorizedN<int64_t, NI> index, at::vec::VectorizedN<T, NV> offset) {
635   constexpr int len = at::vec::VectorizedN<int64_t, NI>::size();
636   static_assert(len <= at::vec::VectorizedN<T, NV>::size());
637   __at_align__ std::array<T, len> tmpbuf;
638   __at_align__ std::array<int64_t, len> tmpidx;
639   offset.store(tmpbuf.data());
640   index.store(tmpidx.data());
641   for (int i = 0; i < len; i++){
642     atomic_add(addr + tmpidx[i], tmpbuf[i]);
643   }
644 }
645 #endif
646 
647 std::tuple<std::shared_ptr<int64_t[]>, int> _get_factors(int64_t number) {
648   int count = 0;
649   for (int64_t i = std::sqrt(number); i > 0; --i) {
650     if (number % i == 0) {
651       count += 2;
652     }
653   }
654   auto factors = std::shared_ptr<int64_t[]>(new int64_t[count]);
655   int index = 0;
656   for (int64_t i = std::sqrt(number); i > 0; --i) {
657     if (number % i == 0) {
658       factors[index++] = number / i;
659       factors[index++] = i;
660     }
661   }
662   return std::make_tuple(factors, count);
663 }
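
// Example (worked by hand): _get_factors enumerates complete factor pairs
// starting from the divisor closest to sqrt(number), larger co-factor first.
//
//   auto [factors, count] = _get_factors(12);
//   // count == 6, factors == {4, 3, 6, 2, 12, 1}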
664 
665 std::tuple<std::shared_ptr<int64_t[]>, int> get_factors(int64_t number) {
666   thread_local std::map<int64_t, std::tuple<std::shared_ptr<int64_t[]>, int>> cache;
667   auto it = cache.find(number);
668   if (it != cache.end()) {
669     return it->second;
670   } else {
671     auto factors = _get_factors(number);
672     cache[number] = factors;
673     return factors;
674   }
675 }
676 
677 void _mm_get_thread_blocking(
678     int num_threads,
679     int max_k_slices,
680     int64_t M,
681     int64_t N,
682     int64_t K,
683     int64_t Mr,
684     int64_t Nr,
685     int64_t Kr,
686     int64_t& Mt,
687     int64_t& Nt,
688     int64_t& Kt) {
689   // see NOTE [Thread blocking in Cpp GEMM] for heuristics
690   Mt = Nt = Kt = 0;
691 
692   auto get_blocking = [](int64_t m_factor,
693                          int64_t n_factor,
694                          int64_t k_factor,
695                          int64_t m_blocks,
696                          int64_t n_blocks,
697                          int64_t k_blocks) {
698     int64_t thread_block_k = (k_blocks + k_factor - 1) / k_factor;
699     int64_t thread_block_n = (n_blocks + n_factor - 1) / n_factor;
700     int64_t thread_block_m = (m_blocks + m_factor - 1) / m_factor;
701     return std::make_tuple(thread_block_m, thread_block_n, thread_block_k);
702   };
703 
704   auto is_better_blocking = [=](int64_t Mt_,
705                               int64_t Nt_,
706                               int64_t Kt_,
707                               int64_t Mt,
708                               int64_t Nt,
709                               int64_t Kt) {
710     return Mt == 0 || Kt_ < Kt || Mt_ * Mr + Nt_ * Nr < Mt * Mr + Nt * Nr;
711   };
712 
713   int64_t m_blocks = (M + Mr - 1) / Mr;
714   int64_t n_blocks = (N + Nr - 1) / Nr;
715   int64_t k_blocks = (K + Kr - 1) / Kr;
716 
717   auto [factors, count] = get_factors(num_threads);
718   assert(count > 0);
719 
720   for (int i = 0; i < count; ++i) {
721     int64_t n_factor = factors[i];
722     int64_t m_factor = num_threads / n_factor;
723     if (n_blocks >= n_factor && m_blocks >= m_factor) {
724       auto [Mt_, Nt_, Kt_] = get_blocking(
725           m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks);
726       if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
727         std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
728       }
729     }
730   }
731 
732   if (Mt != 0) {
733     return;
734   }
735 
736   for (int i = 0; i < count; ++i) {
737     int64_t k_factor = factors[i];
738     if (k_blocks >= k_factor && (max_k_slices == 0 || k_factor <= max_k_slices)) {
739       auto [mxn_factors, mxn_count] = get_factors(num_threads / k_factor);
740       for (int j = 0; j < mxn_count; ++j) {
741         int64_t n_factor = mxn_factors[j];
742         int64_t m_factor = num_threads / (k_factor * n_factor);
743         if (n_blocks >= n_factor && m_blocks >= m_factor) {
744           auto [Mt_, Nt_, Kt_] = get_blocking(
745               m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks);
746           if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
747             std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
748           }
749         }
750       }
751     }
752   }
753 
754   if (Mt != 0) {
755     return;
756   }
757 
758   for (int i = 0; i < count; ++i) {
759     int64_t n_factor = factors[i];
760     int64_t m_factor = num_threads / n_factor;
761     if (n_blocks >= n_factor || m_blocks >= m_factor) {
762       auto [Mt_, Nt_, Kt_] = get_blocking(
763           m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks);
764       if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
765         std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
766       }
767     }
768   }
769 
770   assert(Mt != 0);
771 }
772 
773 void mm_get_thread_blocking(
774     int num_threads,
775     int max_k_slices,
776     int64_t M,
777     int64_t N,
778     int64_t K,
779     int64_t Mr,
780     int64_t Nr,
781     int64_t Kr,
782     int64_t& Mt,
783     int64_t& Nt,
784     int64_t& Kt) {
785   thread_local std::map<
786     std::tuple<int, int, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>,
787     std::tuple<int64_t, int64_t, int64_t>> cache;
788   auto key = std::make_tuple(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr);
789   auto it = cache.find(key);
790   if (it != cache.end()) {
791     std::tie(Mt, Nt, Kt) = it->second;
792     return;
793   } else {
794     _mm_get_thread_blocking(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr, Mt, Nt, Kt);
795     cache[key] = std::make_tuple(Mt, Nt, Kt);
796   }
797 }
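
// Example (illustrative, values traced by hand through the heuristic above):
// with num_threads == 4, M == 128, N == 256, K == 512 and Mr == Nr == Kr == 32
// (m_blocks == 4, n_blocks == 8, k_blocks == 16), the first pass settles on a
// 2 x 2 thread grid over M x N with no K slicing.
//
//   int64_t Mt, Nt, Kt;
//   mm_get_thread_blocking(4, /*max_k_slices=*/1, 128, 256, 512, 32, 32, 32,
//                          Mt, Nt, Kt);
//   // Mt == 2, Nt == 4, Kt == 16  (per-thread block counts along M, N, K)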
798 
799 template<typename X_t, typename W_t>
800 void _mm_get_cache_blocking(
801     int num_threads,
802     int64_t M,
803     int64_t N,
804     int64_t K,
805     int64_t Mr,
806     int64_t Nr,
807     int64_t Kr,
808     int64_t Mt_blocks,
809     int64_t Nt_blocks,
810     int64_t Kt_blocks,
811     int64_t& Mc_blocks,
812     int64_t& Nc_blocks,
813     int64_t& Kc_blocks,
814     uint32_t L1_cache_size,
815     uint32_t L2_cache_size) {
816   // See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking algorithm.
817   // TODO(jgong5): cache the cache-blocking results
818   // TODO: tune the factor here
819   float L1_limit_factor = 0.8;
820   float L2_limit_factor = 0.5;
821 
822   auto L1 = L1_cache_size * L1_limit_factor;
823   auto L2 = L2_cache_size * L2_limit_factor;
824 
825   constexpr size_t num_byte_A = sizeof(X_t);
826   constexpr size_t num_byte_B = sizeof(W_t);
827 
828   int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B;
829   Kc_blocks = Kt_blocks;
830   if (size_cache_B > L1) {
831       Kc_blocks = (int64_t)std::floor(L1 / (Kr * Nr * num_byte_B));
832   }
833 
834   float min_Mc_ratio = 2;
835   int64_t min_Mc_blocks = std::ceil(min_Mc_ratio * Mr / Nr);
836   auto Kt_bytes = Kt_blocks * Kr * num_byte_A;
837   if (min_Mc_blocks * Mr * Kt_bytes < L2) {
838     Mc_blocks = std::min(Mt_blocks, (int64_t)std::floor(L2 / (Mr * Kt_bytes)));
839     Nc_blocks = 1;
840   } else {
841     Mc_blocks = Mt_blocks;
842     Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
843     auto Nc_bytes = Nc_blocks * Nr * 4;
844     auto Kc_bytes = Kc_blocks * Kr * num_byte_A;
845     if (Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2) {
846       auto M_max = (std::sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8;
847       if (M_max < Mc_blocks * Mr) {
848         Mc_blocks = (int64_t)std::floor(M_max / Mr);
849         Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
850       }
851     }
852   }
853 }
854 
855 template<typename X_t, typename W_t>
856 void mm_get_cache_blocking(
857     int num_threads,
858     int64_t M,
859     int64_t N,
860     int64_t K,
861     int64_t Mr,
862     int64_t Nr,
863     int64_t Kr,
864     int64_t Mt_blocks,
865     int64_t Nt_blocks,
866     int64_t Kt_blocks,
867     int64_t& Mc_blocks,
868     int64_t& Nc_blocks,
869     int64_t& Kc_blocks,
870     uint32_t L1_cache_size,
871     uint32_t L2_cache_size) {
872   thread_local std::map<
873     std::tuple<int, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>,
874     std::tuple<int64_t, int64_t, int64_t>> cache;
875   auto key = std::make_tuple(num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, L1_cache_size, L2_cache_size);
876   auto it = cache.find(key);
877   if (it != cache.end()) {
878     std::tie(Mc_blocks, Nc_blocks, Kc_blocks) = it->second;
879     return;
880   } else {
881     _mm_get_cache_blocking<X_t, W_t>(
882         num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, Mc_blocks, Nc_blocks, Kc_blocks, L1_cache_size, L2_cache_size);
883     cache[key] = std::make_tuple(Mc_blocks, Nc_blocks, Kc_blocks);
884   }
885 }
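
// Example (illustrative, traced by hand with assumed float inputs and
// L1_cache_size == 32768, L2_cache_size == 1048576): continuing the thread
// blocking above (Mr == Nr == Kr == 32, Mt_blocks == 2, Nt_blocks == 4,
// Kt_blocks == 16), the B panel for the full Kt exceeds the L1 budget, so Kc
// is shrunk, while the A panel fits the L2 budget:
//
//   int64_t Mc, Nc, Kc;
//   mm_get_cache_blocking<float, float>(4, 128, 256, 512, 32, 32, 32,
//                                       2, 4, 16, Mc, Nc, Kc, 32768, 1048576);
//   // Mc == 2, Nc == 1, Kc == 6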
886 
887 inline void mm_get_thread_blocks(
888     int thread_id,
889     int64_t M_blocks,
890     int64_t N_blocks,
891     int64_t K_blocks,
892     int64_t Mt_blocks,
893     int64_t Nt_blocks,
894     int64_t Kt_blocks,
895     int64_t& m_block_start,
896     int64_t& m_block_end,
897     int64_t& n_block_start,
898     int64_t& n_block_end,
899     int64_t& k_block_start,
900     int64_t& k_block_end) {
901   int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks;
902   k_block_start = (thread_id % num_Kt) * Kt_blocks;
903   k_block_end = std::min(k_block_start + Kt_blocks, K_blocks);
904   thread_id /= num_Kt;
905   int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks;
906   n_block_start = (thread_id % num_Nt) * Nt_blocks;
907   n_block_end = std::min(n_block_start + Nt_blocks, N_blocks);
908   thread_id /= num_Nt;
909   m_block_start = std::min(thread_id * Mt_blocks, M_blocks);
910   m_block_end = std::min(m_block_start + Mt_blocks, M_blocks);
911 }
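
// Example (illustrative, continuing the blocking above with M_blocks == 4,
// N_blocks == 8, K_blocks == 16 and Mt == 2, Nt == 4, Kt == 16): thread_id is
// decomposed K-innermost, then N, then M.
//
//   int64_t ms, me, ns, ne, ks, ke;
//   mm_get_thread_blocks(/*thread_id=*/3, 4, 8, 16, 2, 4, 16,
//                        ms, me, ns, ne, ks, ke);
//   // ms == 2, me == 4, ns == 4, ne == 8, ks == 0, ke == 16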
912 
913 struct amx_tilecfg {
914   uint8_t palette_id;
915   uint8_t start_row;
916   uint8_t reserved_0[14];
917   uint16_t colsb[16];
918   uint8_t rows[16];
919 };
920 
921 class AMXState {
922  private:
923   amx_tilecfg tilecfg_;
924   uint8_t rows_;
925   uint16_t colsb_;
926   uint8_t num_tile_rows_;
927   uint8_t num_tile_columns_;
928 
929  public:
930   AMXState() : rows_(0), colsb_(0), num_tile_rows_(0), num_tile_columns_(0) {
931     memset(&tilecfg_, 0, sizeof(tilecfg_));
932   }
933 
934   inline void configure(
935       uint8_t rows,
936       uint16_t colsb,
937       uint8_t num_tile_rows,
938       uint8_t num_tile_columns,
939       void (*loadconfig)(const amx_tilecfg&)) {
940     if (tilecfg_.palette_id == 1 && rows_ == rows && colsb_ == colsb &&
941         num_tile_rows_ == num_tile_rows &&
942         num_tile_columns_ == num_tile_columns) {
943       return;
944     }
945     tilecfg_.palette_id = 1;
946     rows_ = rows;
947     colsb_ = colsb;
948     num_tile_rows_ = num_tile_rows;
949     num_tile_columns_ = num_tile_columns;
950     const auto num_c_tiles = num_tile_rows * num_tile_columns;
951     // For C
952     for (int i = 0; i < num_c_tiles; i++) {
953       tilecfg_.rows[i] = rows;
954       tilecfg_.colsb[i] = 64;
955     }
956     // For A
957     for (int i = 0; i < num_tile_rows; i++) {
958       tilecfg_.rows[i + num_c_tiles] = rows;
959       tilecfg_.colsb[i + num_c_tiles] = colsb;
960     }
961     // For B
962     for (int i = 0; i < num_tile_columns; i++) {
963       tilecfg_.rows[i + num_c_tiles + num_tile_rows] = colsb / 4;
964       tilecfg_.colsb[i + num_c_tiles + num_tile_rows] = 64;
965     }
966     loadconfig(tilecfg_);
967   }
968 
969   inline void release(void (*tile_release)()) {
970     tilecfg_.palette_id = 0;
971     tile_release();
972   }
973 };
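
// Example (illustrative sketch of how a caller might use AMXState): keep one
// state per thread and pass the tile intrinsics as callbacks. The intrinsic
// names below (_tile_loadconfig, _tile_release from <immintrin.h>) are
// assumptions of this sketch, not requirements of this header.
//
//   AMXState amx_state;
//   amx_state.configure(/*rows=*/16, /*colsb=*/64,
//                       /*num_tile_rows=*/2, /*num_tile_columns=*/2,
//                       [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); });
//   // ... _tile_loadd / _tile_dpbf16ps / _tile_stored ...
//   amx_state.release([]() { _tile_release(); });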
974