xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/signature_runner.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_SIGNATURE_RUNNER_H_
16 #define TENSORFLOW_LITE_SIGNATURE_RUNNER_H_
17 
18 #include <cstddef>
19 #include <cstdint>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/core/subgraph.h"
25 #include "tensorflow/lite/internal/signature_def.h"
26 
27 namespace tflite {
28 class Interpreter;  // Class for friend declarations.
29 class SignatureRunnerJNIHelper;  // Class for friend declarations.
30 class TensorHandle;              // Class for friend declarations.
31 class SignatureRunnerHelper;     // Class for friend declarations.
32 
33 /// WARNING: Experimental interface, subject to change
34 ///
35 /// SignatureRunner class for running TFLite models using SignatureDef.
36 ///
37 /// Usage:
38 ///
39 /// <pre><code>
40 /// // Create model from file. Note that the model instance must outlive the
41 /// // interpreter instance.
42 /// auto model = tflite::FlatBufferModel::BuildFromFile(...);
43 /// if (model == nullptr) {
44 ///   // Return error.
45 /// }
46 ///
47 /// // Create an Interpreter with an InterpreterBuilder.
48 /// std::unique_ptr<tflite::Interpreter> interpreter;
49 /// tflite::ops::builtin::BuiltinOpResolver resolver;
50 /// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
51 ///   // Return failure.
52 /// }
53 ///
54 /// // Get the list of signatures and check it.
55 /// auto signature_defs = interpreter->signature_def_names();
56 /// if (signature_defs.empty()) {
57 ///   // Return error.
58 /// }
59 ///
60 /// // Get pointer to the SignatureRunner instance corresponding to a signature.
61 /// // Note that the pointed SignatureRunner instance has lifetime same as the
62 /// // Interpreter instance.
63 /// tflite::SignatureRunner* runner =
64 ///                interpreter->GetSignatureRunner(signature_defs[0]->c_str());
65 /// if (runner == nullptr) {
66 ///   // Return error.
67 /// }
68 /// if (runner->AllocateTensors() != kTfLiteOk) {
69 ///   // Return failure.
70 /// }
71 ///
72 /// // Set input data. In this example, the input tensor has float type.
73 /// float* input = runner->input_tensor(0)->data.f;
74 /// for (int i = 0; i < input_size; i++) {
75 ///   input[i] = ...;
76 //  }
77 /// runner->Invoke();
78 /// </code></pre>
79 ///
80 /// WARNING: This class is *not* thread-safe. The client is responsible for
81 /// ensuring serialized interaction to avoid data races and undefined behavior.
82 ///
83 /// SignatureRunner and Interpreter share the same underlying data. Calling
84 /// methods on an Interpreter object will affect the state in corresponding
85 /// SignatureRunner objects. Therefore, it is recommended not to call other
86 /// Interpreter methods after calling GetSignatureRunner to create
87 /// SignatureRunner instances.
88 class SignatureRunner {
89  public:
90   /// Returns the key for the corresponding signature.
signature_key()91   const std::string& signature_key() { return signature_def_->signature_key; }
92 
93   /// Returns the number of inputs.
input_size()94   size_t input_size() const { return subgraph_->inputs().size(); }
95 
96   /// Returns the number of outputs.
output_size()97   size_t output_size() const { return subgraph_->outputs().size(); }
98 
99   /// Read-only access to list of signature input names.
input_names()100   const std::vector<const char*>& input_names() { return input_names_; }
101 
102   /// Read-only access to list of signature output names.
output_names()103   const std::vector<const char*>& output_names() { return output_names_; }
104 
105   /// Returns the input tensor identified by 'input_name' in the
106   /// given signature. Returns nullptr if the given name is not valid.
107   TfLiteTensor* input_tensor(const char* input_name);
108 
109   /// Returns the output tensor identified by 'output_name' in the
110   /// given signature. Returns nullptr if the given name is not valid.
111   const TfLiteTensor* output_tensor(const char* output_name) const;
112 
113   /// Change a dimensionality of a given tensor. Note, this is only acceptable
114   /// for tensors that are inputs.
115   /// Returns status of failure or success. Note that this doesn't actually
116   /// resize any existing buffers. A call to AllocateTensors() is required to
117   /// change the tensor input buffer.
118   TfLiteStatus ResizeInputTensor(const char* input_name,
119                                  const std::vector<int>& new_size);
120 
121   /// Change the dimensionality of a given tensor. This is only acceptable for
122   /// tensor indices that are inputs or variables. Only unknown dimensions can
123   /// be resized with this function. Unknown dimensions are indicated as `-1` in
124   /// the `dims_signature` attribute of a TfLiteTensor.
125   /// Returns status of failure or success. Note that this doesn't actually
126   /// resize any existing buffers. A call to AllocateTensors() is required to
127   /// change the tensor input buffer.
128   TfLiteStatus ResizeInputTensorStrict(const char* input_name,
129                                        const std::vector<int>& new_size);
130 
131   /// Updates allocations for all tensors, related to the given signature.
AllocateTensors()132   TfLiteStatus AllocateTensors() { return subgraph_->AllocateTensors(); }
133 
134   /// Invokes the signature runner (run the graph identified by the given
135   /// signature in dependency order).
136   TfLiteStatus Invoke();
137 
138  private:
139   // The life cycle of SignatureRunner depends on the life cycle of Subgraph,
140   // which is owned by an Interpreter. Therefore, the Interpreter will takes the
141   // responsibility to create and manage SignatureRunner objects to make sure
142   // SignatureRunner objects don't outlive their corresponding Subgraph objects.
143   SignatureRunner(const internal::SignatureDef* signature_def,
144                   Subgraph* subgraph);
145   friend class Interpreter;
146   friend class SignatureRunnerJNIHelper;
147   friend class TensorHandle;
148   friend class SignatureRunnerHelper;
149 
150   // The SignatureDef object is owned by the interpreter.
151   const internal::SignatureDef* signature_def_;
152   // The Subgraph object is owned by the interpreter.
153   Subgraph* subgraph_;
154   // The list of input tensor names.
155   std::vector<const char*> input_names_;
156   // The list of output tensor names.
157   std::vector<const char*> output_names_;
158 };
159 
160 }  // namespace tflite
161 
162 #endif  // TENSORFLOW_LITE_SIGNATURE_RUNNER_H_
163