xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/CL/UNIT/MLGOHeuristics.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24*c217d954SCole Faust #include "src/runtime/CL/mlgo/MLGOHeuristics.h"
25*c217d954SCole Faust #include "src/runtime/CL/mlgo/Utils.h"
26*c217d954SCole Faust #include "tests/framework/Asserts.h"
27*c217d954SCole Faust #include "tests/framework/Macros.h"
28*c217d954SCole Faust 
29*c217d954SCole Faust using namespace arm_compute::mlgo;
30*c217d954SCole Faust 
31*c217d954SCole Faust namespace arm_compute
32*c217d954SCole Faust {
33*c217d954SCole Faust namespace test
34*c217d954SCole Faust {
35*c217d954SCole Faust namespace validation
36*c217d954SCole Faust {
37*c217d954SCole Faust TEST_SUITE(CL)
TEST_SUITE(UNIT)38*c217d954SCole Faust TEST_SUITE(UNIT)
39*c217d954SCole Faust TEST_SUITE(MLGOHeuristics)
40*c217d954SCole Faust TEST_CASE(CorrectDotMLGOShouldLoadCorrectly, framework::DatasetMode::ALL)
41*c217d954SCole Faust {
42*c217d954SCole Faust     std::string       mlgo_str = R"_(
43*c217d954SCole Faust 
44*c217d954SCole Faust         <header>
45*c217d954SCole Faust 
46*c217d954SCole Faust         gemm-version, [1,2,1]
47*c217d954SCole Faust         ip-type,gpu
48*c217d954SCole Faust         </header>
49*c217d954SCole Faust         <heuristics-table>
50*c217d954SCole Faust         0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
51*c217d954SCole Faust 
52*c217d954SCole Faust         1, g71 , 8, f16, best-performance, static, gemm-config-reshaped-only-rhs, [m,n,k,n]
53*c217d954SCole Faust         2, g76 , 8, f16, best-performance, static, gemm-config-reshaped, [m,n,k,n]
54*c217d954SCole Faust         </heuristics-table>
55*c217d954SCole Faust         <heuristic, 0>
56*c217d954SCole Faust         b , 0, var, m, ==, num, 10., 1, 2
57*c217d954SCole Faust         l , 1, gemm-type, reshaped
58*c217d954SCole Faust         b , 2, var, r_mn, >=, num, 2., 3, 6
59*c217d954SCole Faust 
60*c217d954SCole Faust         b , 3, var, n, >=, num, 200., 4, 5
61*c217d954SCole Faust         l, 4,                          gemm-type, reshaped-only-rhs
62*c217d954SCole Faust         l , 5, gemm-type, reshaped
63*c217d954SCole Faust         l , 6, gemm-type, reshaped-only-rhs
64*c217d954SCole Faust         </heuristic>
65*c217d954SCole Faust         <heuristic, 1>
66*c217d954SCole Faust         b ,0,var, n, >, num, 100., 1, 4
67*c217d954SCole Faust         b ,1,var, r_mnk, <=, num, 20., 2, 3
68*c217d954SCole Faust 
69*c217d954SCole Faust 
70*c217d954SCole Faust         l ,2,gemm-config-reshaped-only-rhs, [4, 4,4,2,1,0,1]
71*c217d954SCole Faust         l ,3,gemm-config-reshaped-only-rhs,[ 2, 2,4,2,1,1, 1 ]
72*c217d954SCole Faust         b ,4,var, n, >=, num, 199.12, 5, 6
73*c217d954SCole Faust         l ,5,gemm-config-reshaped-only-rhs, [1, 4,3,4,0,0,0]
74*c217d954SCole Faust         l ,6,gemm-config-reshaped-only-rhs, [5, 4,4,5,1,1,0]
75*c217d954SCole Faust         </heuristic>
76*c217d954SCole Faust 
77*c217d954SCole Faust         <heuristic, 2>
78*c217d954SCole Faust         l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
79*c217d954SCole Faust 
80*c217d954SCole Faust         </heuristic>
81*c217d954SCole Faust 
82*c217d954SCole Faust     )_";
83*c217d954SCole Faust     std::stringstream ss(mlgo_str);
84*c217d954SCole Faust     MLGOHeuristics    heuristics;
85*c217d954SCole Faust     heuristics.reload_from_stream(ss);
86*c217d954SCole Faust 
87*c217d954SCole Faust     ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 10, 1024, 20, 1 }).second == GEMMType::RESHAPED, framework::LogLevel::ERRORS);
88*c217d954SCole Faust     ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 201, 5, 1 }).second == GEMMType::RESHAPED_ONLY_RHS, framework::LogLevel::ERRORS);
89*c217d954SCole Faust     ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 200, 199, 16 }).second == GEMMType::RESHAPED_ONLY_RHS, framework::LogLevel::ERRORS);
90*c217d954SCole Faust     ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 199, 512, 4 }).second == GEMMType::RESHAPED, framework::LogLevel::ERRORS);
91*c217d954SCole Faust 
92*c217d954SCole Faust     ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 100, 1024, 20, 32 }).second == GEMMConfigReshapedOnlyRHS{ 4, 4, 4, 2, true, false, true }),
93*c217d954SCole Faust                        framework::LogLevel::ERRORS);
94*c217d954SCole Faust     ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 100, 1024, 20, 32 }).second == GEMMConfigReshapedOnlyRHS{ 4, 4, 4, 2, true, false, true }),
95*c217d954SCole Faust                        framework::LogLevel::ERRORS);
96*c217d954SCole Faust     ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 128, 101, 20, 1 }).second == GEMMConfigReshapedOnlyRHS{ 2, 2, 4, 2, true, true, true }),
97*c217d954SCole Faust                        framework::LogLevel::ERRORS);
98*c217d954SCole Faust     ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 400, 100, 512, 1 }).second == GEMMConfigReshapedOnlyRHS{ 5, 4, 4, 5, true, true, false }),
99*c217d954SCole Faust                        framework::LogLevel::ERRORS);
100*c217d954SCole Faust     ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 400, 100, 512, 1 }).second == GEMMConfigReshapedOnlyRHS{ 5, 4, 4, 5, true, true, false }),
101*c217d954SCole Faust                        framework::LogLevel::ERRORS);
102*c217d954SCole Faust 
103*c217d954SCole Faust     ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F16, 100, 100, 20, 32 }).second == GEMMConfigReshaped{ 4, 2, 4, 2, 8, true, false, true, false }),
104*c217d954SCole Faust                        framework::LogLevel::ERRORS);
105*c217d954SCole Faust     ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F16, 128, 512, 1024, 1 }).second == GEMMConfigReshaped{ 4, 2, 4, 2, 8, true, false, true, false }),
106*c217d954SCole Faust                        framework::LogLevel::ERRORS);
107*c217d954SCole Faust }
108*c217d954SCole Faust 
/** A syntactically broken dotmlgo document must be rejected by reload_from_stream. */
TEST_CASE(InvalidDotmlgoSyntaxShouldReturnInvalidStatus, framework::DatasetMode::ALL)
{
    // The document is malformed on purpose: among other things the closing tag of the
    // heuristics table is truncated ("</heurist") and the ip-type reads "pu".
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,pu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]

        </heurist
        <heuristic, 0>
        l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
128*c217d954SCole Faust 
129*c217d954SCole Faust TEST_SUITE(InvalidDotmlgoSemanticsShouldReturnInvalidStatus)
// If the semantic errors are local to some trees rather than to the entire heuristics, an alternative is to simply
// ignore/remove those invalid trees. However, the reason we choose to throw, thus invalidating the entire
// heuristics, is that if some trees are invalid, the quality of the dotmlgo is called into question even if
// the rest of the trees are semantically valid, and they could severely degrade the performance of GEMM. Therefore
// this "all or nothing" approach to dotmlgo correctness is safer and more defensive.

// Also note that a semantic error of a tree refers only to errors that obstruct its evaluation, and thus its query
// (e.g. invalid tree structure, unsupported features, etc.), not to those affecting the desired outcome
// (usually in terms of final GEMM performance, e.g. the effectiveness of the decision tree).

// In the future we might want to check the content of the exceptions as well. For now it suffices to
// know that it throws exactly when it needs to.
/** The heuristics table and the heuristic tree sections must describe the same set of heuristics.
 *
 * Three independent documents are checked, each in its own scope; every one of them must fail to load.
 */
TEST_CASE(MismatchesBetweenHeuristicsTableEntriesAndHeuristicTrees, framework::DatasetMode::ALL)
{
    {
        // Mismatching number of entries 1: the table lists heuristic 0 but no tree section follows
        const std::string dotmlgo(R"_(
            <header>
            gemm-version, [1,2,1]
            ip-type,gpu
            </header>
            <heuristics-table>

            0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]

            </heuristics-table>
        )_");
        MLGOHeuristics    heuristics;
        std::stringstream ss(dotmlgo);

        // NOTE: This case might throw an internal error as the tree inserted by the heuristics-table cannot be checked
        ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
    }

    {
        // Mismatching number of entries 2: a tree section for heuristic 1 exists but the table is empty
        const std::string dotmlgo(R"_(
            <header>
            gemm-version, [1,2,1]
            ip-type,gpu
            </header>
            <heuristics-table>
            </heuristics-table>
            <heuristic, 1>
            l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
            </heuristic>
        )_");
        MLGOHeuristics    heuristics;
        std::stringstream ss(dotmlgo);

        ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
    }

    {
        // Mismatching info: the table declares heuristic 0 as gemm-type, but its leaf is a gemm-config-reshaped
        const std::string dotmlgo(R"_(
            <header>
            gemm-version, [1,2,1]
            ip-type,gpu
            </header>
            <heuristics-table>
            0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
            </heuristics-table>
            <heuristic, 0>
            l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
            </heuristic>
        )_");
        MLGOHeuristics    heuristics;
        std::stringstream ss(dotmlgo);

        ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
    }
}
200*c217d954SCole Faust 
/** Two heuristics-table rows sharing the same id must invalidate the document. */
TEST_CASE(RepeatedHeuristicsTableEntriesId, framework::DatasetMode::ALL)
{
    // Both table rows (g76 and g71) carry id 0, although trees 0 and 1 are both defined
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        0, g71 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
        <heuristic, 1>
        l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
223*c217d954SCole Faust 
/** Two heuristics-table rows that differ only in their id must invalidate the document. */
TEST_CASE(RepeatedHeuristicsTableEntriesIndex, framework::DatasetMode::ALL)
{
    // Rows 0 and 1 describe the same g76 / f32 / gemm-config-reshaped combination.
    // NOTE(review): presumably that combination is the lookup "index" the test name refers to - confirm.
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        1, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
        <heuristic, 1>
        l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
246*c217d954SCole Faust 
/** Two heuristic tree sections sharing the same id must invalidate the document. */
TEST_CASE(RepeatedHeuristicTreesId, framework::DatasetMode::ALL)
{
    // The table lists ids 0 and 1, but both tree sections are tagged <heuristic, 0>
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        1, g71 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
        <heuristic, 0>
        l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
/** A heuristic tree section without any node must invalidate the document. */
TEST_CASE(EmptyTree, framework::DatasetMode::ALL)
{
    // <heuristic, 0> is opened and closed immediately, with no branch or leaf in between
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
286*c217d954SCole Faust 
/** A heuristic tree that has no root node must invalidate the document. */
TEST_CASE(InvalidTreeMissingRoot, framework::DatasetMode::ALL)
{
    // The tree only defines nodes 2, 3 and 4; there is no node with id 0
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        b ,2, var, m, ==, num, 10., 3, 4
        l ,3,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        l ,4,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
/** A branch that refers to a node which is never defined must invalidate the document. */
TEST_CASE(InvalidTreeMissingNodes, framework::DatasetMode::ALL)
{
    // Branch 0 names nodes 1 and 2 as its successors, but only leaf 1 is defined
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        b ,0, var, m, ==, num, 10., 1, 2
        l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
/** Two nodes of the same tree sharing an id must invalidate the document. */
TEST_CASE(InvalidTreeRepeatedNodeIds, framework::DatasetMode::ALL)
{
    // Leaf id 1 is defined twice, with different configurations
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        b ,0, var, m, ==, num, 10., 1, 2
        l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        l ,1,gemm-config-reshaped,[1,2,4,2,8,1,0,1,0]
        l ,2,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
/** Nodes that are not connected to the root of their tree must invalidate the document. */
TEST_CASE(InvalidTreeDisjointNodes, framework::DatasetMode::ALL)
{
    // Nodes 0, 1 and 2 form a complete tree; the group 4/5/6 and the lone leaf 7 are
    // not referenced by any node of that tree
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        b ,0, var, m, ==, num, 10., 1, 2
        l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        l ,2,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]

        b ,4, var, n, ==, num, 10., 5, 6
        l ,5,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        l ,6,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]

        l ,7,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
/** A branch that names itself as a successor must invalidate the document. */
TEST_CASE(InvalidTreeLoop, framework::DatasetMode::ALL)
{
    // Branch 0 lists node 0 (itself) as one of its two successors
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        b ,0, var, m, ==, num, 10., 0, 1
        l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
/** A tree whose branches form a cycle spanning several nodes must invalidate the document. */
TEST_CASE(InvalidTreeCycle, framework::DatasetMode::ALL)
{
    // Branch 3 points back to node 0, closing the cycle 0 -> 1 -> 3 -> 0
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        b ,0, var, m, ==, num, 10., 1, 5
        b ,1, var, n, ==, num, 10., 2, 3
        l ,2,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        b ,3, var, k, ==, num, 10., 0, 4
        l ,4,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        l ,5,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
/** A branch condition on an unknown feature must invalidate the document. */
TEST_CASE(InvalidTreeInvalidFeatures, framework::DatasetMode::ALL)
{
    // Branch 0 tests a variable called "magic_feature"; the tree structure itself is complete
    const std::string dotmlgo(R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        b ,0, var, magic_feature, ==, num, 10., 1, 2
        l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        l ,2,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
        </heuristic>
    )_");
    MLGOHeuristics    heuristics;
    std::stringstream ss(dotmlgo);

    ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
}
435*c217d954SCole Faust TEST_SUITE_END() // InvalidDotmlgoSemanticsShouldReturnInvalidStatus
436*c217d954SCole Faust 
/** Queries that a successfully loaded dotmlgo cannot answer must report an invalid status.
 *
 * The document provides exactly one heuristic: gemm-type for g76 / f32. Each query below asks for
 * something outside of that, and the first element of the returned pair is expected to be false.
 */
TEST_CASE(InvalidUsageOfHeuristicsShouldReturnInvalidStatus, framework::DatasetMode::ALL)
{
    std::string       mlgo_str = R"_(
        <header>
        gemm-version, [1,2,1]
        ip-type,gpu
        </header>
        <heuristics-table>
        0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
        </heuristics-table>
        <heuristic, 0>
        b , 0, var, m, ==, num, 10., 1, 2
        l , 1, gemm-type, reshaped
        b , 2, var, r_mn, >=, num, 2., 3, 6
        b , 3, var, n, >=, num, 200., 4, 5
        l , 4, gemm-type, reshaped-only-rhs
        l , 5, gemm-type, reshaped
        l , 6, gemm-type, reshaped-only-rhs
        </heuristic>
    )_";
    std::stringstream ss(mlgo_str);
    MLGOHeuristics    heuristics;
    // The document itself is valid, so loading must succeed before the invalid queries are made
    ARM_COMPUTE_EXPECT(heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);

    // Querying unavailable heuristic type should return invalid Status
    ARM_COMPUTE_EXPECT(!heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F32, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
    // Querying unavailable ip target should return invalid Status
    ARM_COMPUTE_EXPECT(!heuristics.query_gemm_type(Query{ "g77", DataType::F32, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
    // Querying unavailable data type should return invalid Status
    // (this query also uses a heuristic type that the document does not provide)
    ARM_COMPUTE_EXPECT(!heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g76", DataType::QASYMM8, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
    // Same check with the heuristic type and ip target that ARE available, so that the data type
    // is the only reason for the query to be rejected
    ARM_COMPUTE_EXPECT(!heuristics.query_gemm_type(Query{ "g76", DataType::QASYMM8, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
}
468*c217d954SCole Faust TEST_SUITE_END() // MLGOHeuristics
469*c217d954SCole Faust TEST_SUITE_END() // UNIT
470*c217d954SCole Faust TEST_SUITE_END() // CL
471*c217d954SCole Faust } // namespace validation
472*c217d954SCole Faust } // namespace test
473*c217d954SCole Faust } // namespace arm_compute
474