xref: /aosp_15_r20/external/pytorch/caffe2/perfkernels/common.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker // !!!! PLEASE READ !!!!
// Minimize (transitively) included headers from _avx*.cc files: functions
// defined in those headers are compiled there with platform-dependent compiler
// options and can be reused by other translation units, which then fail at run
// time with an illegal-instruction error.
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker // Common utilities for writing performance kernels and easy dispatching of
8*da0073e9SAndroid Build Coastguard Worker // different backends.
9*da0073e9SAndroid Build Coastguard Worker /*
The general workflow is as follows. Say we want to
implement a function called void foo(int a, float b).
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker In foo.h, do:
14*da0073e9SAndroid Build Coastguard Worker    void foo(int a, float b);
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker In foo_avx512.cc, do:
17*da0073e9SAndroid Build Coastguard Worker    void foo__avx512(int a, float b) {
18*da0073e9SAndroid Build Coastguard Worker      [actual avx512 implementation]
19*da0073e9SAndroid Build Coastguard Worker    }
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker In foo_avx2.cc, do:
22*da0073e9SAndroid Build Coastguard Worker    void foo__avx2(int a, float b) {
23*da0073e9SAndroid Build Coastguard Worker      [actual avx2 implementation]
24*da0073e9SAndroid Build Coastguard Worker    }
25*da0073e9SAndroid Build Coastguard Worker 
26*da0073e9SAndroid Build Coastguard Worker In foo_avx.cc, do:
27*da0073e9SAndroid Build Coastguard Worker    void foo__avx(int a, float b) {
28*da0073e9SAndroid Build Coastguard Worker      [actual avx implementation]
29*da0073e9SAndroid Build Coastguard Worker    }
30*da0073e9SAndroid Build Coastguard Worker 
31*da0073e9SAndroid Build Coastguard Worker In foo.cc, do:
32*da0073e9SAndroid Build Coastguard Worker    // The base implementation should *always* be provided.
33*da0073e9SAndroid Build Coastguard Worker    void foo__base(int a, float b) {
34*da0073e9SAndroid Build Coastguard Worker      [base, possibly slow implementation]
35*da0073e9SAndroid Build Coastguard Worker    }
36*da0073e9SAndroid Build Coastguard Worker    decltype(foo__base) foo__avx512;
37*da0073e9SAndroid Build Coastguard Worker    decltype(foo__base) foo__avx2;
38*da0073e9SAndroid Build Coastguard Worker    decltype(foo__base) foo__avx;
39*da0073e9SAndroid Build Coastguard Worker    void foo(int a, float b) {
40*da0073e9SAndroid Build Coastguard Worker      // You should always order things by their preference, faster
41*da0073e9SAndroid Build Coastguard Worker      // implementations earlier in the function.
42*da0073e9SAndroid Build Coastguard Worker      AVX512_DO(foo, a, b);
43*da0073e9SAndroid Build Coastguard Worker      AVX2_DO(foo, a, b);
44*da0073e9SAndroid Build Coastguard Worker      AVX_DO(foo, a, b);
45*da0073e9SAndroid Build Coastguard Worker      BASE_DO(foo, a, b);
46*da0073e9SAndroid Build Coastguard Worker    }
47*da0073e9SAndroid Build Coastguard Worker 
48*da0073e9SAndroid Build Coastguard Worker */
49*da0073e9SAndroid Build Coastguard Worker // Details: this functionality basically covers the cases for both build time
50*da0073e9SAndroid Build Coastguard Worker // and run time architecture support.
51*da0073e9SAndroid Build Coastguard Worker //
52*da0073e9SAndroid Build Coastguard Worker // During build time:
53*da0073e9SAndroid Build Coastguard Worker //    The build system should provide flags CAFFE2_PERF_WITH_AVX512,
//    CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that correspond to the
55*da0073e9SAndroid Build Coastguard Worker //    __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__, and __AVX__ flags the
56*da0073e9SAndroid Build Coastguard Worker //    compiler provides. Note that we do not use the compiler flags but rely on
57*da0073e9SAndroid Build Coastguard Worker //    the build system flags, because the common files (like foo.cc above) will
58*da0073e9SAndroid Build Coastguard Worker //    always be built without __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__
59*da0073e9SAndroid Build Coastguard Worker //    and __AVX__.
// During run time:
//    we use cpuinfo to detect what the CPU supports and dispatch to the first
//    supported implementation, in the order the DO macros are listed.
62*da0073e9SAndroid Build Coastguard Worker 
63*da0073e9SAndroid Build Coastguard Worker #pragma once
64*da0073e9SAndroid Build Coastguard Worker 
65*da0073e9SAndroid Build Coastguard Worker #if defined(CAFFE2_PERF_WITH_AVX512) || defined(CAFFE2_PERF_WITH_AVX2) \
66*da0073e9SAndroid Build Coastguard Worker      || defined(CAFFE2_PERF_WITH_AVX)
67*da0073e9SAndroid Build Coastguard Worker #include <cpuinfo.h>
68*da0073e9SAndroid Build Coastguard Worker #endif
69*da0073e9SAndroid Build Coastguard Worker 
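// DO macros: use these in your entry function (like foo() above), which routes
// calls to implementations based on CPU capability. Each macro pastes a suffix
// onto funcname: AVX512_DO -> funcname__avx512, AVX2_DO -> funcname__avx2,
// AVX2_FMA_DO -> funcname__avx2_fma, AVX_DO -> funcname__avx,
// AVX_F16C_DO -> funcname__avx_f16c, and BASE_DO -> funcname__base. The
// ISA-specific macros return from the calling function when the CPU supports
// that ISA and otherwise fall through; BASE_DO always returns.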
72*da0073e9SAndroid Build Coastguard Worker 
// BASE_DO: unconditional fallback. Expands to
//   return funcname__base(<remaining macro arguments>);
// so it must be the last DO macro in the dispatching function. The expansion
// already ends with ';' (a trailing ';' at the call site is a harmless empty
// statement).
#define BASE_DO(funcname, ...) \
  return funcname##__base(__VA_ARGS__);
74*da0073e9SAndroid Build Coastguard Worker 
#ifdef CAFFE2_PERF_WITH_AVX512
// AVX512_DO: if cpuinfo initializes successfully and reports AVX-512 F, DQ and
// VL support, returns funcname__avx512(<remaining macro arguments>) from the
// enclosing function; otherwise does nothing and execution falls through to
// the next DO macro. The CPU check is cached in a function-local static, so
// the cpuinfo queries run only the first time this expansion is reached.
// Wrapped in do { } while (0) so that `AVX512_DO(...);` is exactly one
// statement, which keeps it safe as the body of an if/else.
#define AVX512_DO(funcname, ...)                                   \
  do {                                                             \
    static const bool isDo = cpuinfo_initialize() &&               \
        cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512dq() && \
        cpuinfo_has_x86_avx512vl();                                \
    if (isDo) {                                                    \
      return funcname##__avx512(__VA_ARGS__);                      \
    }                                                              \
  } while (0)
#else // CAFFE2_PERF_WITH_AVX512
// Built without AVX-512 kernels: a no-op statement, so call sites stay valid.
#define AVX512_DO(funcname, ...) \
  do {                           \
  } while (0)
#endif // CAFFE2_PERF_WITH_AVX512
88*da0073e9SAndroid Build Coastguard Worker 
#ifdef CAFFE2_PERF_WITH_AVX2
// AVX2_DO: if cpuinfo initializes successfully and reports AVX2 support,
// returns funcname__avx2(<remaining macro arguments>) from the enclosing
// function; otherwise falls through. The CPU check is cached in a
// function-local static. Wrapped in do { } while (0) so that
// `AVX2_DO(...);` is exactly one statement (safe inside if/else).
#define AVX2_DO(funcname, ...)                                               \
  do {                                                                       \
    static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx2(); \
    if (isDo) {                                                              \
      return funcname##__avx2(__VA_ARGS__);                                  \
    }                                                                        \
  } while (0)
// AVX2_FMA_DO: like AVX2_DO, but additionally requires FMA3 support and
// dispatches to funcname__avx2_fma(<remaining macro arguments>).
#define AVX2_FMA_DO(funcname, ...)                                             \
  do {                                                                         \
    static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx2() && \
        cpuinfo_has_x86_fma3();                                                \
    if (isDo) {                                                                \
      return funcname##__avx2_fma(__VA_ARGS__);                                \
    }                                                                          \
  } while (0)
#else // CAFFE2_PERF_WITH_AVX2
// Built without AVX2 kernels: no-op statements, so call sites stay valid.
#define AVX2_DO(funcname, ...) \
  do {                         \
  } while (0)
#define AVX2_FMA_DO(funcname, ...) \
  do {                             \
  } while (0)
#endif // CAFFE2_PERF_WITH_AVX2
109*da0073e9SAndroid Build Coastguard Worker 
#ifdef CAFFE2_PERF_WITH_AVX
// AVX_DO: if cpuinfo initializes successfully and reports AVX support,
// returns funcname__avx(<remaining macro arguments>) from the enclosing
// function; otherwise falls through. The CPU check is cached in a
// function-local static. Wrapped in do { } while (0) so that
// `AVX_DO(...);` is exactly one statement (safe inside if/else).
#define AVX_DO(funcname, ...)                                               \
  do {                                                                      \
    static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx(); \
    if (isDo) {                                                             \
      return funcname##__avx(__VA_ARGS__);                                  \
    }                                                                       \
  } while (0)
// AVX_F16C_DO: like AVX_DO, but additionally requires F16C support and
// dispatches to funcname__avx_f16c(<remaining macro arguments>).
#define AVX_F16C_DO(funcname, ...)                                            \
  do {                                                                        \
    static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx() && \
        cpuinfo_has_x86_f16c();                                               \
    if (isDo) {                                                               \
      return funcname##__avx_f16c(__VA_ARGS__);                               \
    }                                                                         \
  } while (0)
#else // CAFFE2_PERF_WITH_AVX
// Built without AVX kernels: no-op statements, so call sites stay valid.
#define AVX_DO(funcname, ...) \
  do {                        \
  } while (0)
#define AVX_F16C_DO(funcname, ...) \
  do {                             \
  } while (0)
#endif // CAFFE2_PERF_WITH_AVX
130