xref: /aosp_15_r20/hardware/interfaces/neuralnetworks/utils/adapter/hidl/src/PreparedModel.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "PreparedModel.h"
18 
19 #include "Burst.h"
20 
21 #include <android-base/logging.h>
22 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
23 #include <android/hardware/neuralnetworks/1.0/types.h>
24 #include <android/hardware/neuralnetworks/1.2/IBurstCallback.h>
25 #include <android/hardware/neuralnetworks/1.2/IExecutionCallback.h>
26 #include <android/hardware/neuralnetworks/1.2/types.h>
27 #include <android/hardware/neuralnetworks/1.3/IExecutionCallback.h>
28 #include <android/hardware/neuralnetworks/1.3/IFencedExecutionCallback.h>
29 #include <android/hardware/neuralnetworks/1.3/IPreparedModel.h>
30 #include <android/hardware/neuralnetworks/1.3/types.h>
31 #include <nnapi/IPreparedModel.h>
32 #include <nnapi/TypeUtils.h>
33 #include <nnapi/Types.h>
34 #include <nnapi/Validation.h>
35 #include <nnapi/hal/1.0/Utils.h>
36 #include <nnapi/hal/1.2/Utils.h>
37 #include <nnapi/hal/1.3/Conversions.h>
38 #include <nnapi/hal/1.3/Utils.h>
39 
40 #include <memory>
41 #include <thread>
42 
43 // See hardware/interfaces/neuralnetworks/utils/README.md for more information on HIDL interface
44 // lifetimes across processes and for protecting asynchronous calls across HIDL.
45 
46 namespace android::hardware::neuralnetworks::adapter {
47 namespace {
48 
49 template <typename Type>
convertInput(const Type & object)50 auto convertInput(const Type& object) -> decltype(nn::convert(std::declval<Type>())) {
51     auto result = nn::convert(object);
52     if (!result.has_value()) {
53         result.error().code = nn::ErrorStatus::INVALID_ARGUMENT;
54     }
55     return result;
56 }
57 
58 class FencedExecutionCallback final : public V1_3::IFencedExecutionCallback {
59   public:
FencedExecutionCallback(const nn::ExecuteFencedInfoCallback & callback)60     explicit FencedExecutionCallback(const nn::ExecuteFencedInfoCallback& callback)
61         : kCallback(callback) {
62         CHECK(callback != nullptr);
63     }
64 
getExecutionInfo(getExecutionInfo_cb cb)65     Return<void> getExecutionInfo(getExecutionInfo_cb cb) override {
66         const auto result = kCallback();
67         if (!result.has_value()) {
68             const auto& [message, code] = result.error();
69             const auto status =
70                     V1_3::utils::convert(code).value_or(V1_3::ErrorStatus::GENERAL_FAILURE);
71             LOG(ERROR) << message;
72             cb(status, V1_2::utils::kNoTiming, V1_2::utils::kNoTiming);
73             return Void();
74         }
75         const auto [timingLaunched, timingFenced] = result.value();
76         const auto hidlTimingLaunched = V1_3::utils::convert(timingLaunched).value();
77         const auto hidlTimingFenced = V1_3::utils::convert(timingFenced).value();
78         cb(V1_3::ErrorStatus::NONE, hidlTimingLaunched, hidlTimingFenced);
79         return Void();
80     }
81 
82   private:
83     const nn::ExecuteFencedInfoCallback kCallback;
84 };
85 
// Result type consumed by the templated notify() helper below: on success it carries the output
// shapes together with the timing of an execution.
using ExecutionResult = nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>;
87 
notify(V1_0::IExecutionCallback * callback,nn::ErrorStatus status,const std::vector<nn::OutputShape> &,const nn::Timing &)88 void notify(V1_0::IExecutionCallback* callback, nn::ErrorStatus status,
89             const std::vector<nn::OutputShape>& /*outputShapes*/, const nn::Timing& /*timing*/) {
90     if (callback != nullptr) {
91         const auto hidlStatus = V1_0::utils::convert(status).value();
92         const auto ret = callback->notify(hidlStatus);
93         if (!ret.isOk()) {
94             LOG(ERROR) << "V1_0::IExecutionCallback::notify failed with " << ret.description();
95         }
96     }
97 }
98 
notify(V1_2::IExecutionCallback * callback,nn::ErrorStatus status,const std::vector<nn::OutputShape> & outputShapes,const nn::Timing & timing)99 void notify(V1_2::IExecutionCallback* callback, nn::ErrorStatus status,
100             const std::vector<nn::OutputShape>& outputShapes, const nn::Timing& timing) {
101     if (callback != nullptr) {
102         const auto hidlStatus = V1_2::utils::convert(status).value();
103         const auto hidlOutputShapes = V1_2::utils::convert(outputShapes).value();
104         const auto hidlTiming = V1_2::utils::convert(timing).value();
105         const auto ret = callback->notify_1_2(hidlStatus, hidlOutputShapes, hidlTiming);
106         if (!ret.isOk()) {
107             LOG(ERROR) << "V1_2::IExecutionCallback::notify_1_2 failed with " << ret.description();
108         }
109     }
110 }
111 
notify(V1_3::IExecutionCallback * callback,nn::ErrorStatus status,const std::vector<nn::OutputShape> & outputShapes,const nn::Timing & timing)112 void notify(V1_3::IExecutionCallback* callback, nn::ErrorStatus status,
113             const std::vector<nn::OutputShape>& outputShapes, const nn::Timing& timing) {
114     if (callback != nullptr) {
115         const auto hidlStatus = V1_3::utils::convert(status).value();
116         const auto hidlOutputShapes = V1_3::utils::convert(outputShapes).value();
117         const auto hidlTiming = V1_3::utils::convert(timing).value();
118         const auto ret = callback->notify_1_3(hidlStatus, hidlOutputShapes, hidlTiming);
119         if (!ret.isOk()) {
120             LOG(ERROR) << "V1_3::IExecutionCallback::notify_1_3 failed with " << ret.description();
121         }
122     }
123 }
124 
125 template <typename CallbackType>
notify(CallbackType * callback,ExecutionResult result)126 void notify(CallbackType* callback, ExecutionResult result) {
127     if (!result.has_value()) {
128         const auto [message, status, outputShapes] = std::move(result).error();
129         LOG(ERROR) << message;
130         notify(callback, status, outputShapes, {});
131     } else {
132         const auto [outputShapes, timing] = std::move(result).value();
133         notify(callback, nn::ErrorStatus::NONE, outputShapes, timing);
134     }
135 }
136 
// Implements the V1_0 execute request on top of the canonical prepared model.
//
// Returns an INVALID_ARGUMENT error without notifying the callback when the callback is null,
// when the request cannot be converted, or when the execution itself reports INVALID_ARGUMENT.
// Every other outcome (success or failure) is delivered through `callback` before returning.
nn::GeneralResult<void> execute(const nn::SharedPreparedModel& preparedModel,
                                const V1_0::Request& request,
                                const sp<V1_0::IExecutionCallback>& callback) {
    V1_0::IExecutionCallback* const rawCallback = callback.get();
    if (rawCallback == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
    }

    const auto canonicalRequest = NN_TRY(convertInput(request));

    // The V1_0 entry point has no timing, deadline or loop-timeout parameters.
    auto outcome = preparedModel->execute(canonicalRequest, nn::MeasureTiming::NO, {}, {}, {}, {});

    if (!outcome.ok()) {
        const auto& [message, code, outputShapes] = outcome.error();
        if (code == nn::ErrorStatus::INVALID_ARGUMENT) {
            return nn::error(code) << message;
        }
    }

    notify(rawCallback, std::move(outcome));
    return {};
}
156 
// Implements the V1_2 execute request on top of the canonical prepared model.
//
// Returns an INVALID_ARGUMENT error without notifying the callback when the callback is null,
// when an argument cannot be converted, or when the execution itself reports INVALID_ARGUMENT.
// Every other outcome (success or failure) is delivered through `callback` before returning.
nn::GeneralResult<void> execute_1_2(const nn::SharedPreparedModel& preparedModel,
                                    const V1_0::Request& request, V1_2::MeasureTiming measure,
                                    const sp<V1_2::IExecutionCallback>& callback) {
    V1_2::IExecutionCallback* const rawCallback = callback.get();
    if (rawCallback == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
    }

    const auto canonicalRequest = NN_TRY(convertInput(request));
    const auto canonicalMeasure = NN_TRY(convertInput(measure));

    auto outcome = preparedModel->execute(canonicalRequest, canonicalMeasure, {}, {}, {}, {});

    if (!outcome.ok()) {
        const auto& [message, code, outputShapes] = outcome.error();
        if (code == nn::ErrorStatus::INVALID_ARGUMENT) {
            return nn::error(code) << message;
        }
    }

    notify(rawCallback, std::move(outcome));
    return {};
}
177 
// Implements the V1_3 execute request on top of the canonical prepared model.
//
// Returns an INVALID_ARGUMENT error without notifying the callback when the callback is null,
// when an argument cannot be converted, or when the execution itself reports INVALID_ARGUMENT.
// Every other outcome (success or failure) is delivered through `callback` before returning.
nn::GeneralResult<void> execute_1_3(const nn::SharedPreparedModel& preparedModel,
                                    const V1_3::Request& request, V1_2::MeasureTiming measure,
                                    const V1_3::OptionalTimePoint& deadline,
                                    const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                                    const sp<V1_3::IExecutionCallback>& callback) {
    V1_3::IExecutionCallback* const rawCallback = callback.get();
    if (rawCallback == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
    }

    const auto canonicalRequest = NN_TRY(convertInput(request));
    const auto canonicalMeasure = NN_TRY(convertInput(measure));
    const auto canonicalDeadline = NN_TRY(convertInput(deadline));
    const auto canonicalLoopTimeout = NN_TRY(convertInput(loopTimeoutDuration));

    auto outcome = preparedModel->execute(canonicalRequest, canonicalMeasure, canonicalDeadline,
                                          canonicalLoopTimeout, {}, {});

    if (!outcome.ok()) {
        const auto& [message, code, outputShapes] = outcome.error();
        if (code == nn::ErrorStatus::INVALID_ARGUMENT) {
            return nn::error(code) << message;
        }
    }

    notify(rawCallback, std::move(outcome));
    return {};
}
203 
// Implements the V1_2 synchronous execution: converts the arguments, runs the canonical
// prepared model, and converts the resulting output shapes and timing to their V1_2 form.
// Any failing step ends the function early with that step's error (via NN_TRY).
nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> executeSynchronously(
        const nn::SharedPreparedModel& preparedModel, const V1_0::Request& request,
        V1_2::MeasureTiming measure) {
    const auto canonicalRequest = NN_TRY(convertInput(request));
    const auto canonicalMeasure = NN_TRY(convertInput(measure));

    auto executionOutcome =
            preparedModel->execute(canonicalRequest, canonicalMeasure, {}, {}, {}, {});
    const auto [shapes, executionTiming] = NN_TRY(std::move(executionOutcome));

    auto convertedShapes = NN_TRY(V1_2::utils::convert(shapes));
    const auto convertedTiming = NN_TRY(V1_2::utils::convert(executionTiming));
    return std::make_pair(std::move(convertedShapes), convertedTiming);
}
217 
// Implements the V1_3 synchronous execution: converts the arguments (including the deadline and
// loop timeout), runs the canonical prepared model, and converts the resulting output shapes and
// timing to their HIDL form. Any failing step ends the function early with that step's error
// (via NN_TRY).
nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> executeSynchronously_1_3(
        const nn::SharedPreparedModel& preparedModel, const V1_3::Request& request,
        V1_2::MeasureTiming measure, const V1_3::OptionalTimePoint& deadline,
        const V1_3::OptionalTimeoutDuration& loopTimeoutDuration) {
    const auto canonicalRequest = NN_TRY(convertInput(request));
    const auto canonicalMeasure = NN_TRY(convertInput(measure));
    const auto canonicalDeadline = NN_TRY(convertInput(deadline));
    const auto canonicalLoopTimeout = NN_TRY(convertInput(loopTimeoutDuration));

    auto executionOutcome = preparedModel->execute(canonicalRequest, canonicalMeasure,
                                                   canonicalDeadline, canonicalLoopTimeout, {}, {});
    const auto [shapes, executionTiming] = NN_TRY(std::move(executionOutcome));

    auto convertedShapes = NN_TRY(V1_3::utils::convert(shapes));
    const auto convertedTiming = NN_TRY(V1_3::utils::convert(executionTiming));
    return std::make_pair(std::move(convertedShapes), convertedTiming);
}
234 
convertSyncFences(const hidl_vec<hidl_handle> & handles)235 nn::GeneralResult<std::vector<nn::SyncFence>> convertSyncFences(
236         const hidl_vec<hidl_handle>& handles) {
237     auto nnHandles = NN_TRY(convertInput(handles));
238     std::vector<nn::SyncFence> syncFences;
239     syncFences.reserve(handles.size());
240     for (auto&& handle : nnHandles) {
241         if (auto syncFence = nn::SyncFence::create(std::move(handle)); !syncFence.ok()) {
242             return nn::error(nn::ErrorStatus::INVALID_ARGUMENT) << std::move(syncFence).error();
243         } else {
244             syncFences.push_back(std::move(syncFence).value());
245         }
246     }
247     return syncFences;
248 }
249 
// Asks the canonical prepared model for a burst executor and wraps it, together with the
// caller-provided callback and FMQ channel descriptors, in a Burst object.
// Fails early if the prepared model cannot configure a burst; otherwise the result of
// Burst::create is returned as-is.
nn::GeneralResult<sp<V1_2::IBurstContext>> configureExecutionBurst(
        const nn::SharedPreparedModel& preparedModel, const sp<V1_2::IBurstCallback>& callback,
        const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
        const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel) {
    auto canonicalBurst = NN_TRY(preparedModel->configureExecutionBurst());
    return Burst::create(callback, requestChannel, resultChannel, std::move(canonicalBurst),
                         V1_2::utils::getBurstServerPollingTimeWindow());
}
258 
// Implements the V1_3 fenced execution: converts every argument, launches the fenced execution
// on the canonical prepared model, and converts the results back to HIDL form.
// On success returns the sync fence handle paired with a FencedExecutionCallback wrapping the
// canonical info callback. Any failing step ends the function early with that step's error.
nn::GeneralResult<std::pair<hidl_handle, sp<V1_3::IFencedExecutionCallback>>> executeFenced(
        const nn::SharedPreparedModel& preparedModel, const V1_3::Request& request,
        const hidl_vec<hidl_handle>& waitFor, V1_2::MeasureTiming measure,
        const V1_3::OptionalTimePoint& deadline,
        const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
        const V1_3::OptionalTimeoutDuration& duration) {
    // Argument conversion; the order determines which error is reported first.
    const auto canonicalRequest = NN_TRY(convertInput(request));
    const auto canonicalWaitFor = NN_TRY(convertSyncFences(waitFor));
    const auto canonicalMeasure = NN_TRY(convertInput(measure));
    const auto canonicalDeadline = NN_TRY(convertInput(deadline));
    const auto canonicalLoopTimeout = NN_TRY(convertInput(loopTimeoutDuration));
    const auto canonicalDuration = NN_TRY(convertInput(duration));

    auto [fence, infoCallback] = NN_TRY(preparedModel->executeFenced(
            canonicalRequest, canonicalWaitFor, canonicalMeasure, canonicalDeadline,
            canonicalLoopTimeout, canonicalDuration, {}, {}));

    auto hidlFence = NN_TRY(V1_3::utils::convert(fence.getSharedHandle()));
    auto hidlInfoCallback = sp<FencedExecutionCallback>::make(infoCallback);
    return std::make_pair(std::move(hidlFence), std::move(hidlInfoCallback));
}
280 
281 }  // namespace
282 
// Takes ownership of the given canonical prepared model handle. A null prepared model is a
// fatal error.
PreparedModel::PreparedModel(nn::SharedPreparedModel preparedModel)
    : kPreparedModel(std::move(preparedModel)) {
    CHECK(kPreparedModel != nullptr);
}
287 
// Returns the canonical prepared model this adapter was constructed with.
nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const {
    return kPreparedModel;
}
291 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)292 Return<V1_0::ErrorStatus> PreparedModel::execute(const V1_0::Request& request,
293                                                  const sp<V1_0::IExecutionCallback>& callback) {
294     auto result = adapter::execute(kPreparedModel, request, callback);
295     if (!result.has_value()) {
296         auto [message, code] = std::move(result).error();
297         LOG(ERROR) << "adapter::PreparedModel::execute failed with " << code << ": " << message;
298         notify(callback.get(), code, {}, {});
299         return V1_0::utils::convert(code).value();
300     }
301     return V1_0::ErrorStatus::NONE;
302 }
303 
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)304 Return<V1_0::ErrorStatus> PreparedModel::execute_1_2(const V1_0::Request& request,
305                                                      V1_2::MeasureTiming measure,
306                                                      const sp<V1_2::IExecutionCallback>& callback) {
307     auto result = adapter::execute_1_2(kPreparedModel, request, measure, callback);
308     if (!result.has_value()) {
309         auto [message, code] = std::move(result).error();
310         LOG(ERROR) << "adapter::PreparedModel::execute_1_2 failed with " << code << ": " << message;
311         notify(callback.get(), code, {}, {});
312         return V1_2::utils::convert(code).value();
313     }
314     return V1_0::ErrorStatus::NONE;
315 }
316 
execute_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)317 Return<V1_3::ErrorStatus> PreparedModel::execute_1_3(
318         const V1_3::Request& request, V1_2::MeasureTiming measure,
319         const V1_3::OptionalTimePoint& deadline,
320         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
321         const sp<V1_3::IExecutionCallback>& callback) {
322     auto result = adapter::execute_1_3(kPreparedModel, request, measure, deadline,
323                                        loopTimeoutDuration, callback);
324     if (!result.has_value()) {
325         auto [message, code] = std::move(result).error();
326         LOG(ERROR) << "adapter::PreparedModel::execute_1_3 failed with " << code << ": " << message;
327         notify(callback.get(), code, {}, {});
328         return V1_3::utils::convert(code).value();
329     }
330     return V1_3::ErrorStatus::NONE;
331 }
332 
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)333 Return<void> PreparedModel::executeSynchronously(const V1_0::Request& request,
334                                                  V1_2::MeasureTiming measure,
335                                                  executeSynchronously_cb cb) {
336     auto result = adapter::executeSynchronously(kPreparedModel, request, measure);
337     if (!result.has_value()) {
338         auto [message, code, outputShapes] = std::move(result).error();
339         LOG(ERROR) << "adapter::PreparedModel::executeSynchronously failed with " << code << ": "
340                    << message;
341         cb(V1_2::utils::convert(code).value(), V1_2::utils::convert(outputShapes).value(),
342            V1_2::utils::kNoTiming);
343         return Void();
344     }
345     auto [outputShapes, timing] = std::move(result).value();
346     cb(V1_0::ErrorStatus::NONE, outputShapes, timing);
347     return Void();
348 }
349 
executeSynchronously_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)350 Return<void> PreparedModel::executeSynchronously_1_3(
351         const V1_3::Request& request, V1_2::MeasureTiming measure,
352         const V1_3::OptionalTimePoint& deadline,
353         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
354     auto result = adapter::executeSynchronously_1_3(kPreparedModel, request, measure, deadline,
355                                                     loopTimeoutDuration);
356     if (!result.has_value()) {
357         auto [message, code, outputShapes] = std::move(result).error();
358         LOG(ERROR) << "adapter::PreparedModel::executeSynchronously_1_3 failed with " << code
359                    << ": " << message;
360         cb(V1_3::utils::convert(code).value(), V1_3::utils::convert(outputShapes).value(),
361            V1_2::utils::kNoTiming);
362         return Void();
363     }
364     auto [outputShapes, timing] = std::move(result).value();
365     cb(V1_3::ErrorStatus::NONE, outputShapes, timing);
366     return Void();
367 }
368 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)369 Return<void> PreparedModel::configureExecutionBurst(
370         const sp<V1_2::IBurstCallback>& callback,
371         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
372         const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
373         configureExecutionBurst_cb cb) {
374     auto result = adapter::configureExecutionBurst(kPreparedModel, callback, requestChannel,
375                                                    resultChannel);
376     if (!result.has_value()) {
377         auto [message, code] = std::move(result).error();
378         LOG(ERROR) << "adapter::PreparedModel::configureExecutionBurst failed with " << code << ": "
379                    << message;
380         cb(V1_2::utils::convert(code).value(), nullptr);
381         return Void();
382     }
383     const auto burstContext = std::move(result).value();
384     cb(V1_0::ErrorStatus::NONE, burstContext);
385     return Void();
386 }
387 
executeFenced(const V1_3::Request & request,const hidl_vec<hidl_handle> & waitFor,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const V1_3::OptionalTimeoutDuration & duration,executeFenced_cb callback)388 Return<void> PreparedModel::executeFenced(const V1_3::Request& request,
389                                           const hidl_vec<hidl_handle>& waitFor,
390                                           V1_2::MeasureTiming measure,
391                                           const V1_3::OptionalTimePoint& deadline,
392                                           const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
393                                           const V1_3::OptionalTimeoutDuration& duration,
394                                           executeFenced_cb callback) {
395     auto result = adapter::executeFenced(kPreparedModel, request, waitFor, measure, deadline,
396                                          loopTimeoutDuration, duration);
397     if (!result.has_value()) {
398         auto [message, code] = std::move(result).error();
399         LOG(ERROR) << "adapter::PreparedModel::executeFenced failed with " << code << ": "
400                    << message;
401         callback(V1_3::utils::convert(code).value(), {}, nullptr);
402         return Void();
403     }
404     auto [syncFence, executeFencedCallback] = std::move(result).value();
405     callback(V1_3::ErrorStatus::NONE, syncFence, executeFencedCallback);
406     return Void();
407 }
408 
409 }  // namespace android::hardware::neuralnetworks::adapter
410