xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite-model-executor.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/tflite-model-executor.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
20*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/register.h"
21*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/schema/schema_generated.h"
22*993b0882SAndroid Build Coastguard Worker 
23*993b0882SAndroid Build Coastguard Worker // Forward declaration of custom TensorFlow Lite ops for registration.
24*993b0882SAndroid Build Coastguard Worker namespace tflite {
25*993b0882SAndroid Build Coastguard Worker namespace ops {
26*993b0882SAndroid Build Coastguard Worker namespace builtin {
27*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_GELU();
28*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_ADD();
29*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_CONCATENATION();
30*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_CONV_2D();
31*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
32*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_AVERAGE_POOL_2D();
33*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_EQUAL();
34*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_FULLY_CONNECTED();
35*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_GREATER_EQUAL();
36*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_L2_NORMALIZATION();
37*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_MUL();
38*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_RESHAPE();
39*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_REDUCE_MAX();
40*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_REDUCE_MIN();
41*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_REDUCE_ANY();
42*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SOFTMAX();
43*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_GATHER();
44*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_GATHER_ND();
45*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_IF();
46*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_ROUND();
47*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_ZEROS_LIKE();
48*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_TRANSPOSE();
49*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SUB();
50*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_DIV();
51*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_STRIDED_SLICE();
52*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_EXP();
53*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_TOPK_V2();
54*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SLICE();
55*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SPLIT();
56*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_CAST();
57*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_MAXIMUM();
58*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_MINIMUM();
59*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_NEG();
60*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SLICE();
61*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_LOG();
62*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_LOGISTIC();
63*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SUM();
64*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_PACK();
65*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_DEQUANTIZE();
66*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_MEAN();
67*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_LESS();
68*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_TILE();
69*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SQUARED_DIFFERENCE();
70*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_RSQRT();
71*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_LOG_SOFTMAX();
72*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_WHERE();
73*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_ONE_HOT();
74*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_POW();
75*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_TANH();
76*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_UNIQUE();
77*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_REDUCE_PROD();
78*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SHAPE();
79*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_NOT_EQUAL();
80*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_CUMSUM();
81*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_EXPAND_DIMS();
82*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_FILL();
83*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_PADV2();
84*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_EMBEDDING_LOOKUP();
85*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_GREATER();
86*993b0882SAndroid Build Coastguard Worker }  // namespace builtin
87*993b0882SAndroid Build Coastguard Worker }  // namespace ops
88*993b0882SAndroid Build Coastguard Worker }  // namespace tflite
89*993b0882SAndroid Build Coastguard Worker 
90*993b0882SAndroid Build Coastguard Worker #ifdef TC3_WITH_ACTIONS_OPS
91*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/blacklist.h"
92*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/dist_diversification.h"
93*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/string_projection.h"
94*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/text_encoder.h"
95*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/text_encoder3s.h"
96*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/token_encoder.h"
97*993b0882SAndroid Build Coastguard Worker 
98*993b0882SAndroid Build Coastguard Worker namespace tflite {
99*993b0882SAndroid Build Coastguard Worker namespace ops {
100*993b0882SAndroid Build Coastguard Worker namespace custom {
101*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
102*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
103*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_RAGGED_RANGE();
104*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_RANDOM_UNIFORM();
105*993b0882SAndroid Build Coastguard Worker }  // namespace custom
106*993b0882SAndroid Build Coastguard Worker }  // namespace ops
107*993b0882SAndroid Build Coastguard Worker }  // namespace tflite
108*993b0882SAndroid Build Coastguard Worker 
RegisterSelectedOps(tflite::MutableOpResolver * resolver)109*993b0882SAndroid Build Coastguard Worker void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
110*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
111*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_ADD(),
112*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
113*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
114*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
115*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_CONCATENATION(),
116*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
117*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
118*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
119*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_CONV_2D(),
120*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
121*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/5);
122*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
123*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(),
124*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
125*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/6);
126*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
127*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_AVERAGE_POOL_2D(),
128*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
129*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/1);
130*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
131*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_EQUAL());
132*993b0882SAndroid Build Coastguard Worker 
133*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
134*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_FULLY_CONNECTED(),
135*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
136*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/9);
137*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER_EQUAL,
138*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_GREATER_EQUAL());
139*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
140*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_L2_NORMALIZATION(),
141*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
142*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
143*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
144*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_MUL());
145*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
146*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_RESHAPE());
147*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MAX,
148*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_REDUCE_MAX());
149*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MIN,
150*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_REDUCE_MIN());
151*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_ANY,
152*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_REDUCE_ANY());
153*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
154*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_SOFTMAX(),
155*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
156*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
157*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
158*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_GATHER(),
159*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
160*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
161*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_GATHER_ND,
162*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_GATHER_ND(),
163*993b0882SAndroid Build Coastguard Worker                        /*version=*/2);
164*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_IF,
165*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_IF()),
166*993b0882SAndroid Build Coastguard Worker       resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND,
167*993b0882SAndroid Build Coastguard Worker                            ::tflite::ops::builtin::Register_ROUND());
168*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_ZEROS_LIKE,
169*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_ZEROS_LIKE());
170*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
171*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_TRANSPOSE(),
172*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
173*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
174*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
175*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_SUB(),
176*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
177*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
178*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
179*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_DIV());
180*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
181*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_STRIDED_SLICE(),
182*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
183*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
184*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
185*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_EXP());
186*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
187*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_TOPK_V2(),
188*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
189*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
190*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
191*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_SLICE(),
192*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
193*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/3);
194*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
195*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_SPLIT(),
196*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
197*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/3);
198*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
199*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_CAST());
200*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
201*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_MAXIMUM(),
202*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
203*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
204*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
205*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_MINIMUM(),
206*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
207*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
208*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
209*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_NEG());
210*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
211*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_SLICE(),
212*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
213*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
214*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
215*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_LOG());
216*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_LOGISTIC,
217*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_LOGISTIC());
218*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
219*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_SUM());
220*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
221*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_PACK(),
222*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
223*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
224*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
225*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_DEQUANTIZE(),
226*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
227*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/2);
228*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
229*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_MEAN());
230*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_LESS,
231*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_LESS());
232*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_TILE,
233*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_TILE());
234*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_SQUARED_DIFFERENCE,
235*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_SQUARED_DIFFERENCE());
236*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_RSQRT,
237*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_RSQRT());
238*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_LOG_SOFTMAX,
239*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_LOG_SOFTMAX());
240*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_WHERE,
241*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_WHERE());
242*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_ONE_HOT,
243*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_ONE_HOT(),
244*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
245*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/1);
246*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_POW,
247*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_POW(),
248*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
249*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/1);
250*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_TANH,
251*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_TANH(),
252*993b0882SAndroid Build Coastguard Worker                        /*min_version=*/1,
253*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/1);
254*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_UNIQUE,
255*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_UNIQUE());
256*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_PROD,
257*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_REDUCE_PROD());
258*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
259*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_SHAPE());
260*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_NOT_EQUAL,
261*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_NOT_EQUAL());
262*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_CUMSUM,
263*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_CUMSUM());
264*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS,
265*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_EXPAND_DIMS());
266*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_FILL,
267*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_FILL());
268*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_PADV2,
269*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_PADV2());
270*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_EMBEDDING_LOOKUP,
271*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_EMBEDDING_LOOKUP(),
272*993b0882SAndroid Build Coastguard Worker                        /* min_version=*/1,
273*993b0882SAndroid Build Coastguard Worker                        /*max_version=*/3);
274*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER,
275*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_GREATER());
276*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(::tflite::BuiltinOperator_GELU,
277*993b0882SAndroid Build Coastguard Worker                        ::tflite::ops::builtin::Register_GELU());
278*993b0882SAndroid Build Coastguard Worker }
279*993b0882SAndroid Build Coastguard Worker #else
RegisterSelectedOps(tflite::MutableOpResolver * resolver)280*993b0882SAndroid Build Coastguard Worker void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
281*993b0882SAndroid Build Coastguard Worker   resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
282*993b0882SAndroid Build Coastguard Worker                        tflite::ops::builtin::Register_FULLY_CONNECTED());
283*993b0882SAndroid Build Coastguard Worker }
284*993b0882SAndroid Build Coastguard Worker #endif  // TC3_WITH_ACTIONS_OPS
285*993b0882SAndroid Build Coastguard Worker 
286*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
287*993b0882SAndroid Build Coastguard Worker 
BuildOpResolver()288*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
289*993b0882SAndroid Build Coastguard Worker   return BuildOpResolver([](tflite::MutableOpResolver* mutable_resolver) {});
290*993b0882SAndroid Build Coastguard Worker }
291*993b0882SAndroid Build Coastguard Worker 
BuildOpResolver(const std::function<void (tflite::MutableOpResolver *)> & customize_fn)292*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::OpResolver> BuildOpResolver(
293*993b0882SAndroid Build Coastguard Worker     const std::function<void(tflite::MutableOpResolver*)>& customize_fn) {
294*993b0882SAndroid Build Coastguard Worker #ifdef TC3_USE_SELECTIVE_REGISTRATION
295*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<tflite::MutableOpResolver> resolver(
296*993b0882SAndroid Build Coastguard Worker       new tflite::MutableOpResolver);
297*993b0882SAndroid Build Coastguard Worker   RegisterSelectedOps(resolver.get());
298*993b0882SAndroid Build Coastguard Worker #else
299*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
300*993b0882SAndroid Build Coastguard Worker       new tflite::ops::builtin::BuiltinOpResolver);
301*993b0882SAndroid Build Coastguard Worker #endif
302*993b0882SAndroid Build Coastguard Worker #ifdef TC3_WITH_ACTIONS_OPS
303*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom("DistanceDiversification",
304*993b0882SAndroid Build Coastguard Worker                       tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
305*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom("TextEncoder",
306*993b0882SAndroid Build Coastguard Worker                       tflite::ops::custom::Register_TEXT_ENCODER());
307*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom("TextEncoder3S",
308*993b0882SAndroid Build Coastguard Worker                       tflite::ops::custom::Register_TEXT_ENCODER3S());
309*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom("TokenEncoder",
310*993b0882SAndroid Build Coastguard Worker                       tflite::ops::custom::Register_TOKEN_ENCODER());
311*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom(
312*993b0882SAndroid Build Coastguard Worker       "TFSentencepieceTokenizeOp",
313*993b0882SAndroid Build Coastguard Worker       ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
314*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom("RaggedRange",
315*993b0882SAndroid Build Coastguard Worker                       ::tflite::ops::custom::Register_RAGGED_RANGE());
316*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom(
317*993b0882SAndroid Build Coastguard Worker       "RaggedTensorToTensor",
318*993b0882SAndroid Build Coastguard Worker       ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
319*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom(
320*993b0882SAndroid Build Coastguard Worker       "STRING_PROJECTION",
321*993b0882SAndroid Build Coastguard Worker       ::tflite::ops::custom::libtextclassifier3::Register_STRING_PROJECTION());
322*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom(
323*993b0882SAndroid Build Coastguard Worker       "BLACKLIST",
324*993b0882SAndroid Build Coastguard Worker       ::tflite::ops::custom::libtextclassifier3::Register_BLACKLIST());
325*993b0882SAndroid Build Coastguard Worker   resolver->AddCustom("RandomUniform",
326*993b0882SAndroid Build Coastguard Worker                       ::tflite::ops::custom::Register_RANDOM_UNIFORM());
327*993b0882SAndroid Build Coastguard Worker #endif  // TC3_WITH_ACTIONS_OPS
328*993b0882SAndroid Build Coastguard Worker   customize_fn(resolver.get());
329*993b0882SAndroid Build Coastguard Worker   return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
330*993b0882SAndroid Build Coastguard Worker }
331*993b0882SAndroid Build Coastguard Worker 
TfLiteModelFromModelSpec(const tflite::Model * model_spec)332*993b0882SAndroid Build Coastguard Worker std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
333*993b0882SAndroid Build Coastguard Worker     const tflite::Model* model_spec) {
334*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<const tflite::FlatBufferModel> model(
335*993b0882SAndroid Build Coastguard Worker       tflite::FlatBufferModel::BuildFromModel(model_spec));
336*993b0882SAndroid Build Coastguard Worker   if (!model || !model->initialized()) {
337*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
338*993b0882SAndroid Build Coastguard Worker     return nullptr;
339*993b0882SAndroid Build Coastguard Worker   }
340*993b0882SAndroid Build Coastguard Worker   return model;
341*993b0882SAndroid Build Coastguard Worker }
342*993b0882SAndroid Build Coastguard Worker 
TfLiteModelFromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)343*993b0882SAndroid Build Coastguard Worker std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
344*993b0882SAndroid Build Coastguard Worker     const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
345*993b0882SAndroid Build Coastguard Worker   const tflite::Model* model =
346*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
347*993b0882SAndroid Build Coastguard Worker   flatbuffers::Verifier verifier(model_spec_buffer->data(),
348*993b0882SAndroid Build Coastguard Worker                                  model_spec_buffer->size());
349*993b0882SAndroid Build Coastguard Worker   if (!model->Verify(verifier)) {
350*993b0882SAndroid Build Coastguard Worker     return nullptr;
351*993b0882SAndroid Build Coastguard Worker   }
352*993b0882SAndroid Build Coastguard Worker   return TfLiteModelFromModelSpec(model);
353*993b0882SAndroid Build Coastguard Worker }
354*993b0882SAndroid Build Coastguard Worker 
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)355*993b0882SAndroid Build Coastguard Worker TfLiteModelExecutor::TfLiteModelExecutor(
356*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<const tflite::FlatBufferModel> model)
357*993b0882SAndroid Build Coastguard Worker     : model_(std::move(model)), resolver_(BuildOpResolver()) {}
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,std::unique_ptr<tflite::OpResolver> resolver)358*993b0882SAndroid Build Coastguard Worker TfLiteModelExecutor::TfLiteModelExecutor(
359*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<const tflite::FlatBufferModel> model,
360*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<tflite::OpResolver> resolver)
361*993b0882SAndroid Build Coastguard Worker     : model_(std::move(model)), resolver_(std::move(resolver)) {}
362*993b0882SAndroid Build Coastguard Worker 
CreateInterpreter() const363*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
364*993b0882SAndroid Build Coastguard Worker     const {
365*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<tflite::Interpreter> interpreter;
366*993b0882SAndroid Build Coastguard Worker   tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
367*993b0882SAndroid Build Coastguard Worker   return interpreter;
368*993b0882SAndroid Build Coastguard Worker }
369*993b0882SAndroid Build Coastguard Worker 
370*993b0882SAndroid Build Coastguard Worker template <>
SetInput(const int input_index,const std::vector<std::string> & input_data,tflite::Interpreter * interpreter) const371*993b0882SAndroid Build Coastguard Worker void TfLiteModelExecutor::SetInput(const int input_index,
372*993b0882SAndroid Build Coastguard Worker                                    const std::vector<std::string>& input_data,
373*993b0882SAndroid Build Coastguard Worker                                    tflite::Interpreter* interpreter) const {
374*993b0882SAndroid Build Coastguard Worker   tflite::DynamicBuffer buf;
375*993b0882SAndroid Build Coastguard Worker   for (const std::string& s : input_data) {
376*993b0882SAndroid Build Coastguard Worker     buf.AddString(s.data(), s.length());
377*993b0882SAndroid Build Coastguard Worker   }
378*993b0882SAndroid Build Coastguard Worker   buf.WriteToTensorAsVector(
379*993b0882SAndroid Build Coastguard Worker       interpreter->tensor(interpreter->inputs()[input_index]));
380*993b0882SAndroid Build Coastguard Worker }
381*993b0882SAndroid Build Coastguard Worker 
382*993b0882SAndroid Build Coastguard Worker template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const383*993b0882SAndroid Build Coastguard Worker std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
384*993b0882SAndroid Build Coastguard Worker     const int output_index, const tflite::Interpreter* interpreter) const {
385*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor* output_tensor =
386*993b0882SAndroid Build Coastguard Worker       interpreter->tensor(interpreter->outputs()[output_index]);
387*993b0882SAndroid Build Coastguard Worker   const int num_strings = tflite::GetStringCount(output_tensor);
388*993b0882SAndroid Build Coastguard Worker   std::vector<tflite::StringRef> output(num_strings);
389*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_strings; i++) {
390*993b0882SAndroid Build Coastguard Worker     output[i] = tflite::GetString(output_tensor, i);
391*993b0882SAndroid Build Coastguard Worker   }
392*993b0882SAndroid Build Coastguard Worker   return output;
393*993b0882SAndroid Build Coastguard Worker }
394*993b0882SAndroid Build Coastguard Worker 
395*993b0882SAndroid Build Coastguard Worker template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const396*993b0882SAndroid Build Coastguard Worker std::vector<std::string> TfLiteModelExecutor::Output(
397*993b0882SAndroid Build Coastguard Worker     const int output_index, const tflite::Interpreter* interpreter) const {
398*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> output;
399*993b0882SAndroid Build Coastguard Worker   for (const tflite::StringRef& s :
400*993b0882SAndroid Build Coastguard Worker        Output<tflite::StringRef>(output_index, interpreter)) {
401*993b0882SAndroid Build Coastguard Worker     output.push_back(std::string(s.str, s.len));
402*993b0882SAndroid Build Coastguard Worker   }
403*993b0882SAndroid Build Coastguard Worker   return output;
404*993b0882SAndroid Build Coastguard Worker }
405*993b0882SAndroid Build Coastguard Worker 
406*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
407