1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_SIGNATURE_RUNNER_H_ 16 #define TENSORFLOW_LITE_SIGNATURE_RUNNER_H_ 17 18 #include <cstddef> 19 #include <cstdint> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow/lite/core/subgraph.h" 25 #include "tensorflow/lite/internal/signature_def.h" 26 27 namespace tflite { 28 class Interpreter; // Class for friend declarations. 29 class SignatureRunnerJNIHelper; // Class for friend declarations. 30 class TensorHandle; // Class for friend declarations. 31 class SignatureRunnerHelper; // Class for friend declarations. 32 33 /// WARNING: Experimental interface, subject to change 34 /// 35 /// SignatureRunner class for running TFLite models using SignatureDef. 36 /// 37 /// Usage: 38 /// 39 /// <pre><code> 40 /// // Create model from file. Note that the model instance must outlive the 41 /// // interpreter instance. 42 /// auto model = tflite::FlatBufferModel::BuildFromFile(...); 43 /// if (model == nullptr) { 44 /// // Return error. 45 /// } 46 /// 47 /// // Create an Interpreter with an InterpreterBuilder. 48 /// std::unique_ptr<tflite::Interpreter> interpreter; 49 /// tflite::ops::builtin::BuiltinOpResolver resolver; 50 /// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { 51 /// // Return failure. 52 /// } 53 /// 54 /// // Get the list of signatures and check it. 55 /// auto signature_defs = interpreter->signature_def_names(); 56 /// if (signature_defs.empty()) { 57 /// // Return error. 58 /// } 59 /// 60 /// // Get pointer to the SignatureRunner instance corresponding to a signature. 61 /// // Note that the pointed SignatureRunner instance has lifetime same as the 62 /// // Interpreter instance. 63 /// tflite::SignatureRunner* runner = 64 /// interpreter->GetSignatureRunner(signature_defs[0]->c_str()); 65 /// if (runner == nullptr) { 66 /// // Return error. 67 /// } 68 /// if (runner->AllocateTensors() != kTfLiteOk) { 69 /// // Return failure. 70 /// } 71 /// 72 /// // Set input data. In this example, the input tensor has float type. 73 /// float* input = runner->input_tensor(0)->data.f; 74 /// for (int i = 0; i < input_size; i++) { 75 /// input[i] = ...; 76 // } 77 /// runner->Invoke(); 78 /// </code></pre> 79 /// 80 /// WARNING: This class is *not* thread-safe. The client is responsible for 81 /// ensuring serialized interaction to avoid data races and undefined behavior. 82 /// 83 /// SignatureRunner and Interpreter share the same underlying data. Calling 84 /// methods on an Interpreter object will affect the state in corresponding 85 /// SignatureRunner objects. Therefore, it is recommended not to call other 86 /// Interpreter methods after calling GetSignatureRunner to create 87 /// SignatureRunner instances. 88 class SignatureRunner { 89 public: 90 /// Returns the key for the corresponding signature. signature_key()91 const std::string& signature_key() { return signature_def_->signature_key; } 92 93 /// Returns the number of inputs. input_size()94 size_t input_size() const { return subgraph_->inputs().size(); } 95 96 /// Returns the number of outputs. output_size()97 size_t output_size() const { return subgraph_->outputs().size(); } 98 99 /// Read-only access to list of signature input names. input_names()100 const std::vector<const char*>& input_names() { return input_names_; } 101 102 /// Read-only access to list of signature output names. output_names()103 const std::vector<const char*>& output_names() { return output_names_; } 104 105 /// Returns the input tensor identified by 'input_name' in the 106 /// given signature. Returns nullptr if the given name is not valid. 107 TfLiteTensor* input_tensor(const char* input_name); 108 109 /// Returns the output tensor identified by 'output_name' in the 110 /// given signature. Returns nullptr if the given name is not valid. 111 const TfLiteTensor* output_tensor(const char* output_name) const; 112 113 /// Change a dimensionality of a given tensor. Note, this is only acceptable 114 /// for tensors that are inputs. 115 /// Returns status of failure or success. Note that this doesn't actually 116 /// resize any existing buffers. A call to AllocateTensors() is required to 117 /// change the tensor input buffer. 118 TfLiteStatus ResizeInputTensor(const char* input_name, 119 const std::vector<int>& new_size); 120 121 /// Change the dimensionality of a given tensor. This is only acceptable for 122 /// tensor indices that are inputs or variables. Only unknown dimensions can 123 /// be resized with this function. Unknown dimensions are indicated as `-1` in 124 /// the `dims_signature` attribute of a TfLiteTensor. 125 /// Returns status of failure or success. Note that this doesn't actually 126 /// resize any existing buffers. A call to AllocateTensors() is required to 127 /// change the tensor input buffer. 128 TfLiteStatus ResizeInputTensorStrict(const char* input_name, 129 const std::vector<int>& new_size); 130 131 /// Updates allocations for all tensors, related to the given signature. AllocateTensors()132 TfLiteStatus AllocateTensors() { return subgraph_->AllocateTensors(); } 133 134 /// Invokes the signature runner (run the graph identified by the given 135 /// signature in dependency order). 136 TfLiteStatus Invoke(); 137 138 private: 139 // The life cycle of SignatureRunner depends on the life cycle of Subgraph, 140 // which is owned by an Interpreter. Therefore, the Interpreter will takes the 141 // responsibility to create and manage SignatureRunner objects to make sure 142 // SignatureRunner objects don't outlive their corresponding Subgraph objects. 143 SignatureRunner(const internal::SignatureDef* signature_def, 144 Subgraph* subgraph); 145 friend class Interpreter; 146 friend class SignatureRunnerJNIHelper; 147 friend class TensorHandle; 148 friend class SignatureRunnerHelper; 149 150 // The SignatureDef object is owned by the interpreter. 151 const internal::SignatureDef* signature_def_; 152 // The Subgraph object is owned by the interpreter. 153 Subgraph* subgraph_; 154 // The list of input tensor names. 155 std::vector<const char*> input_names_; 156 // The list of output tensor names. 157 std::vector<const char*> output_names_; 158 }; 159 160 } // namespace tflite 161 162 #endif // TENSORFLOW_LITE_SIGNATURE_RUNNER_H_ 163