/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "PreparedModel.h"

#include "Burst.h"
#include "Execution.h"

#include <aidl/android/hardware/neuralnetworks/BnFencedExecutionCallback.h>
#include <aidl/android/hardware/neuralnetworks/BnPreparedModel.h>
#include <aidl/android/hardware/neuralnetworks/ExecutionResult.h>
#include <aidl/android/hardware/neuralnetworks/FencedExecutionResult.h>
#include <aidl/android/hardware/neuralnetworks/IBurst.h>
#include <aidl/android/hardware/neuralnetworks/Request.h>
#include <android-base/logging.h>
#include <android/binder_auto_utils.h>
#include <nnapi/IExecution.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/SharedMemory.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include <nnapi/hal/aidl/Conversions.h>
#include <nnapi/hal/aidl/Utils.h>

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

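// This adapter wraps a canonical nn::SharedPreparedModel (and the
// nn::SharedExecution objects it creates) behind the AIDL NN HAL surface so
// that a canonical driver object can be served over AIDL. A rough usage
// sketch, assuming the class declarations in PreparedModel.h:
//
//   nn::SharedPreparedModel nnPreparedModel = /* from the canonical driver */;
//   auto aidlPreparedModel =
//           ndk::SharedRefBase::make<PreparedModel>(std::move(nnPreparedModel));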
namespace aidl::android::hardware::neuralnetworks::adapter {
namespace {

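// Adapts a canonical nn::ExecuteFencedInfoCallback to the AIDL
// IFencedExecutionCallback interface. On success it reports the launched and
// fenced timing; on failure it reports no timing and the converted
// ErrorStatus. The binder call itself always returns OK.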
class FencedExecutionCallback : public BnFencedExecutionCallback {
  public:
    FencedExecutionCallback(nn::ExecuteFencedInfoCallback callback)
        : kCallback(std::move(callback)) {}

    ndk::ScopedAStatus getExecutionInfo(Timing* timingLaunched, Timing* timingFenced,
                                        ErrorStatus* errorStatus) override {
        const auto result = kCallback();
        if (result.ok()) {
            const auto& [nnTimingLaunched, nnTimingFenced] = result.value();
            *timingLaunched = utils::convert(nnTimingLaunched).value();
            *timingFenced = utils::convert(nnTimingFenced).value();
            *errorStatus = ErrorStatus::NONE;
        } else {
            constexpr auto kNoTiming = Timing{.timeOnDeviceNs = -1, .timeInDriverNs = -1};
            const auto& [message, code] = result.error();
            LOG(ERROR) << "getExecutionInfo failed with " << code << ": " << message;
            const auto aidlStatus = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
            *timingLaunched = kNoTiming;
            *timingFenced = kNoTiming;
            *errorStatus = aidlStatus;
        }
        return ndk::ScopedAStatus::ok();
    }

  private:
    const nn::ExecuteFencedInfoCallback kCallback;
};

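// Converts an AIDL object to its canonical counterpart, remapping any
// conversion failure to ErrorStatus::INVALID_ARGUMENT so that malformed
// arguments from the client are not reported as GENERAL_FAILURE.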
template <typename Type>
auto convertInput(const Type& object) -> decltype(nn::convert(std::declval<Type>())) {
    auto result = nn::convert(object);
    if (!result.has_value()) {
        result.error().code = nn::ErrorStatus::INVALID_ARGUMENT;
    }
    return result;
}

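// Converts a vector of sync fence file descriptors into canonical
// nn::SyncFence objects, rejecting null or invalid handles.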
nn::GeneralResult<std::vector<nn::SyncFence>> convertSyncFences(
        const std::vector<ndk::ScopedFileDescriptor>& waitFor) {
    auto handles = NN_TRY(convertInput(waitFor));

    constexpr auto valid = [](const nn::SharedHandle& handle) {
        return handle != nullptr && handle->ok();
    };
    if (!std::all_of(handles.begin(), handles.end(), valid)) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid sync fence";
    }

    std::vector<nn::SyncFence> syncFences;
    syncFences.reserve(waitFor.size());
    for (auto& handle : handles) {
        syncFences.push_back(nn::SyncFence::create(std::move(handle)).value());
    }
    return syncFences;
}

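// The AIDL interface encodes optional durations and deadlines as int64_t
// nanosecond values, where -1 means "not provided". The helpers below map
// that encoding to nn::OptionalDuration / nn::OptionalTimePoint and reject
// any other negative value.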
nn::Duration makeDuration(int64_t durationNs) {
    return nn::Duration(std::chrono::nanoseconds(durationNs));
}

nn::GeneralResult<nn::OptionalDuration> makeOptionalDuration(int64_t durationNs) {
    if (durationNs < -1) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid duration " << durationNs;
    }
    return durationNs < 0 ? nn::OptionalDuration{} : makeDuration(durationNs);
}

nn::GeneralResult<nn::OptionalTimePoint> makeOptionalTimePoint(int64_t durationNs) {
    if (durationNs < -1) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid time point " << durationNs;
    }
    return durationNs < 0 ? nn::OptionalTimePoint{} : nn::TimePoint(makeDuration(durationNs));
}

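// Converts the AIDL arguments to canonical types and runs a synchronous
// execution on the prepared model. OUTPUT_INSUFFICIENT_SIZE is not treated as
// an error: it is reported through ExecutionResult::outputSufficientSize so
// the client can resize its output buffers and retry.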
nn::ExecutionResult<ExecutionResult> executeSynchronously(
        const nn::IPreparedModel& preparedModel, const Request& request, bool measureTiming,
        int64_t deadlineNs, int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
    const auto nnRequest = NN_TRY(convertInput(request));
    const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
    auto nnHints = NN_TRY(convertInput(hints));
    auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));

    const auto result =
            preparedModel.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration,
                                  nnHints, nnExtensionNameToPrefix);

    if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        const auto& [message, code, outputShapes] = result.error();
        LOG(ERROR) << "executeSynchronously failed with " << code << ": " << message;
        return ExecutionResult{.outputSufficientSize = false,
                               .outputShapes = utils::convert(outputShapes).value(),
                               .timing = {.timeOnDeviceNs = -1, .timeInDriverNs = -1}};
    }

    const auto& [outputShapes, timing] = NN_TRY(result);
    return ExecutionResult{.outputSufficientSize = true,
                           .outputShapes = utils::convert(outputShapes).value(),
                           .timing = utils::convert(timing).value()};
}

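// Converts the AIDL arguments to canonical types and runs a fenced execution
// on the prepared model. The returned sync fence (if any) is duplicated for
// the AIDL result, and the driver's fenced-info callback is wrapped in a
// FencedExecutionCallback so the client can query timing after completion.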
nn::GeneralResult<FencedExecutionResult> executeFenced(
        const nn::IPreparedModel& preparedModel, const Request& request,
        const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measureTiming,
        int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        const std::vector<TokenValuePair>& hints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
    const auto nnRequest = NN_TRY(convertInput(request));
    const auto nnWaitFor = NN_TRY(convertSyncFences(waitFor));
    const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
    const auto nnDuration = NN_TRY(makeOptionalDuration(durationNs));
    auto nnHints = NN_TRY(convertInput(hints));
    auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));

    auto [syncFence, executeFencedInfoCallback] = NN_TRY(preparedModel.executeFenced(
            nnRequest, nnWaitFor, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, nnDuration,
            nnHints, nnExtensionNameToPrefix));

    ndk::ScopedFileDescriptor fileDescriptor;
    if (syncFence.hasFd()) {
        auto uniqueFd = NN_TRY(nn::dupFd(syncFence.getFd()));
        fileDescriptor = ndk::ScopedFileDescriptor(uniqueFd.release());
    }

    return FencedExecutionResult{.callback = ndk::SharedRefBase::make<FencedExecutionCallback>(
                                         std::move(executeFencedInfoCallback)),
                                 .syncFence = std::move(fileDescriptor)};
}

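// Converts the AIDL arguments to canonical types and creates a reusable
// execution object that can be computed repeatedly with the same request.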
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
        const nn::IPreparedModel& preparedModel, const Request& request, bool measureTiming,
        int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
    const auto nnRequest = NN_TRY(convertInput(request));
    const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
    const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
    auto nnHints = NN_TRY(convertInput(hints));
    auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));

    return preparedModel.createReusableExecution(nnRequest, nnMeasureTiming, nnLoopTimeoutDuration,
                                                 nnHints, nnExtensionNameToPrefix);
}

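// The overloads below mirror executeSynchronously and executeFenced above,
// but operate on an already-created reusable nn::IExecution, whose request
// and execution settings were fixed at creation time.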
nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IExecution& execution,
                                                          int64_t deadlineNs) {
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));

    const auto result = execution.compute(nnDeadline);

    if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        const auto& [message, code, outputShapes] = result.error();
        LOG(ERROR) << "executeSynchronously failed with " << code << ": " << message;
        return ExecutionResult{.outputSufficientSize = false,
                               .outputShapes = utils::convert(outputShapes).value(),
                               .timing = {.timeOnDeviceNs = -1, .timeInDriverNs = -1}};
    }

    const auto& [outputShapes, timing] = NN_TRY(result);
    return ExecutionResult{.outputSufficientSize = true,
                           .outputShapes = utils::convert(outputShapes).value(),
                           .timing = utils::convert(timing).value()};
}

nn::GeneralResult<FencedExecutionResult> executeFenced(
        const nn::IExecution& execution, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
        int64_t deadlineNs, int64_t durationNs) {
    const auto nnWaitFor = NN_TRY(convertSyncFences(waitFor));
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    const auto nnDuration = NN_TRY(makeOptionalDuration(durationNs));

    auto [syncFence, executeFencedInfoCallback] =
            NN_TRY(execution.computeFenced(nnWaitFor, nnDeadline, nnDuration));

    ndk::ScopedFileDescriptor fileDescriptor;
    if (syncFence.hasFd()) {
        auto uniqueFd = NN_TRY(nn::dupFd(syncFence.getFd()));
        fileDescriptor = ndk::ScopedFileDescriptor(uniqueFd.release());
    }

    return FencedExecutionResult{.callback = ndk::SharedRefBase::make<FencedExecutionCallback>(
                                         std::move(executeFencedInfoCallback)),
                                 .syncFence = std::move(fileDescriptor)};
}

}  // namespace

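// The AIDL PreparedModel methods below forward to the helpers above; any
// canonical failure is translated into a service-specific binder status whose
// error code is the corresponding AIDL ErrorStatus.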
PreparedModel::PreparedModel(nn::SharedPreparedModel preparedModel)
    : kPreparedModel(std::move(preparedModel)) {
    CHECK(kPreparedModel != nullptr);
}

ndk::ScopedAStatus PreparedModel::executeSynchronously(const Request& request, bool measureTiming,
                                                       int64_t deadlineNs,
                                                       int64_t loopTimeoutDurationNs,
                                                       ExecutionResult* executionResult) {
    auto result = adapter::executeSynchronously(*kPreparedModel, request, measureTiming, deadlineNs,
                                                loopTimeoutDurationNs, {}, {});
    if (!result.has_value()) {
        const auto& [message, code, _] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *executionResult = std::move(result).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus PreparedModel::executeFenced(
        const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        FencedExecutionResult* executionResult) {
    auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, measureTiming,
                                         deadlineNs, loopTimeoutDurationNs, durationNs, {}, {});
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *executionResult = std::move(result).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus PreparedModel::executeSynchronouslyWithConfig(const Request& request,
                                                                 const ExecutionConfig& config,
                                                                 int64_t deadlineNs,
                                                                 ExecutionResult* executionResult) {
    auto result = adapter::executeSynchronously(
            *kPreparedModel, request, config.measureTiming, deadlineNs,
            config.loopTimeoutDurationNs, config.executionHints, config.extensionNameToPrefix);
    if (!result.has_value()) {
        const auto& [message, code, _] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *executionResult = std::move(result).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus PreparedModel::executeFencedWithConfig(
        const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
        const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
        FencedExecutionResult* executionResult) {
    auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, config.measureTiming,
                                         deadlineNs, config.loopTimeoutDurationNs, durationNs,
                                         config.executionHints, config.extensionNameToPrefix);
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *executionResult = std::move(result).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus PreparedModel::configureExecutionBurst(std::shared_ptr<IBurst>* burst) {
    auto result = kPreparedModel->configureExecutionBurst();
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *burst = ndk::SharedRefBase::make<Burst>(std::move(result).value());
    return ndk::ScopedAStatus::ok();
}

nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const {
    return kPreparedModel;
}

ndk::ScopedAStatus PreparedModel::createReusableExecution(const Request& request,
                                                          const ExecutionConfig& config,
                                                          std::shared_ptr<IExecution>* execution) {
    auto result = adapter::createReusableExecution(
            *kPreparedModel, request, config.measureTiming, config.loopTimeoutDurationNs,
            config.executionHints, config.extensionNameToPrefix);
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *execution = ndk::SharedRefBase::make<Execution>(std::move(result).value());
    return ndk::ScopedAStatus::ok();
}

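// The AIDL Execution methods below follow the same pattern for reusable
// executions: forward to the nn::IExecution helpers and report canonical
// failures as service-specific binder errors.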
Execution::Execution(nn::SharedExecution execution) : kExecution(std::move(execution)) {
    CHECK(kExecution != nullptr);
}

ndk::ScopedAStatus Execution::executeSynchronously(int64_t deadlineNs,
                                                   ExecutionResult* executionResult) {
    auto result = adapter::executeSynchronously(*kExecution, deadlineNs);
    if (!result.has_value()) {
        const auto& [message, code, _] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *executionResult = std::move(result).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Execution::executeFenced(const std::vector<ndk::ScopedFileDescriptor>& waitFor,
                                            int64_t deadlineNs, int64_t durationNs,
                                            FencedExecutionResult* executionResult) {
    auto result = adapter::executeFenced(*kExecution, waitFor, deadlineNs, durationNs);
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *executionResult = std::move(result).value();
    return ndk::ScopedAStatus::ok();
}

}  // namespace aidl::android::hardware::neuralnetworks::adapter