1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "PreparedModel.h"
18
19 #include "Burst.h"
20
21 #include <android-base/logging.h>
22 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
23 #include <android/hardware/neuralnetworks/1.0/types.h>
24 #include <android/hardware/neuralnetworks/1.2/IBurstCallback.h>
25 #include <android/hardware/neuralnetworks/1.2/IExecutionCallback.h>
26 #include <android/hardware/neuralnetworks/1.2/types.h>
27 #include <android/hardware/neuralnetworks/1.3/IExecutionCallback.h>
28 #include <android/hardware/neuralnetworks/1.3/IFencedExecutionCallback.h>
29 #include <android/hardware/neuralnetworks/1.3/IPreparedModel.h>
30 #include <android/hardware/neuralnetworks/1.3/types.h>
31 #include <nnapi/IPreparedModel.h>
32 #include <nnapi/TypeUtils.h>
33 #include <nnapi/Types.h>
34 #include <nnapi/Validation.h>
35 #include <nnapi/hal/1.0/Utils.h>
36 #include <nnapi/hal/1.2/Utils.h>
37 #include <nnapi/hal/1.3/Conversions.h>
38 #include <nnapi/hal/1.3/Utils.h>
39
40 #include <memory>
41 #include <thread>
42
43 // See hardware/interfaces/neuralnetworks/utils/README.md for more information on HIDL interface
44 // lifetimes across processes and for protecting asynchronous calls across HIDL.
45
46 namespace android::hardware::neuralnetworks::adapter {
47 namespace {
48
49 template <typename Type>
convertInput(const Type & object)50 auto convertInput(const Type& object) -> decltype(nn::convert(std::declval<Type>())) {
51 auto result = nn::convert(object);
52 if (!result.has_value()) {
53 result.error().code = nn::ErrorStatus::INVALID_ARGUMENT;
54 }
55 return result;
56 }
57
58 class FencedExecutionCallback final : public V1_3::IFencedExecutionCallback {
59 public:
FencedExecutionCallback(const nn::ExecuteFencedInfoCallback & callback)60 explicit FencedExecutionCallback(const nn::ExecuteFencedInfoCallback& callback)
61 : kCallback(callback) {
62 CHECK(callback != nullptr);
63 }
64
getExecutionInfo(getExecutionInfo_cb cb)65 Return<void> getExecutionInfo(getExecutionInfo_cb cb) override {
66 const auto result = kCallback();
67 if (!result.has_value()) {
68 const auto& [message, code] = result.error();
69 const auto status =
70 V1_3::utils::convert(code).value_or(V1_3::ErrorStatus::GENERAL_FAILURE);
71 LOG(ERROR) << message;
72 cb(status, V1_2::utils::kNoTiming, V1_2::utils::kNoTiming);
73 return Void();
74 }
75 const auto [timingLaunched, timingFenced] = result.value();
76 const auto hidlTimingLaunched = V1_3::utils::convert(timingLaunched).value();
77 const auto hidlTimingFenced = V1_3::utils::convert(timingFenced).value();
78 cb(V1_3::ErrorStatus::NONE, hidlTimingLaunched, hidlTimingFenced);
79 return Void();
80 }
81
82 private:
83 const nn::ExecuteFencedInfoCallback kCallback;
84 };
85
// Result of a canonical execution: on success, the pair of output shapes and timing information.
// Consumed by the templated notify() helper in this file.
using ExecutionResult = nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>;
87
notify(V1_0::IExecutionCallback * callback,nn::ErrorStatus status,const std::vector<nn::OutputShape> &,const nn::Timing &)88 void notify(V1_0::IExecutionCallback* callback, nn::ErrorStatus status,
89 const std::vector<nn::OutputShape>& /*outputShapes*/, const nn::Timing& /*timing*/) {
90 if (callback != nullptr) {
91 const auto hidlStatus = V1_0::utils::convert(status).value();
92 const auto ret = callback->notify(hidlStatus);
93 if (!ret.isOk()) {
94 LOG(ERROR) << "V1_0::IExecutionCallback::notify failed with " << ret.description();
95 }
96 }
97 }
98
notify(V1_2::IExecutionCallback * callback,nn::ErrorStatus status,const std::vector<nn::OutputShape> & outputShapes,const nn::Timing & timing)99 void notify(V1_2::IExecutionCallback* callback, nn::ErrorStatus status,
100 const std::vector<nn::OutputShape>& outputShapes, const nn::Timing& timing) {
101 if (callback != nullptr) {
102 const auto hidlStatus = V1_2::utils::convert(status).value();
103 const auto hidlOutputShapes = V1_2::utils::convert(outputShapes).value();
104 const auto hidlTiming = V1_2::utils::convert(timing).value();
105 const auto ret = callback->notify_1_2(hidlStatus, hidlOutputShapes, hidlTiming);
106 if (!ret.isOk()) {
107 LOG(ERROR) << "V1_2::IExecutionCallback::notify_1_2 failed with " << ret.description();
108 }
109 }
110 }
111
notify(V1_3::IExecutionCallback * callback,nn::ErrorStatus status,const std::vector<nn::OutputShape> & outputShapes,const nn::Timing & timing)112 void notify(V1_3::IExecutionCallback* callback, nn::ErrorStatus status,
113 const std::vector<nn::OutputShape>& outputShapes, const nn::Timing& timing) {
114 if (callback != nullptr) {
115 const auto hidlStatus = V1_3::utils::convert(status).value();
116 const auto hidlOutputShapes = V1_3::utils::convert(outputShapes).value();
117 const auto hidlTiming = V1_3::utils::convert(timing).value();
118 const auto ret = callback->notify_1_3(hidlStatus, hidlOutputShapes, hidlTiming);
119 if (!ret.isOk()) {
120 LOG(ERROR) << "V1_3::IExecutionCallback::notify_1_3 failed with " << ret.description();
121 }
122 }
123 }
124
125 template <typename CallbackType>
notify(CallbackType * callback,ExecutionResult result)126 void notify(CallbackType* callback, ExecutionResult result) {
127 if (!result.has_value()) {
128 const auto [message, status, outputShapes] = std::move(result).error();
129 LOG(ERROR) << message;
130 notify(callback, status, outputShapes, {});
131 } else {
132 const auto [outputShapes, timing] = std::move(result).value();
133 notify(callback, nn::ErrorStatus::NONE, outputShapes, timing);
134 }
135 }
136
// Implements the V1_0 execute path on top of a canonical prepared model.
//
// Returns an INVALID_ARGUMENT error if `callback` is null, if `request` cannot be converted, or
// if the underlying execution itself reports INVALID_ARGUMENT. Every other execution outcome
// (success or failure) is delivered through `callback` and the function returns success.
// Timing measurement is disabled (nn::MeasureTiming::NO) for this path.
nn::GeneralResult<void> execute(const nn::SharedPreparedModel& preparedModel,
                                const V1_0::Request& request,
                                const sp<V1_0::IExecutionCallback>& callback) {
    if (callback.get() == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
    }

    const auto nnRequest = NN_TRY(convertInput(request));

    auto executionResult =
            preparedModel->execute(nnRequest, nn::MeasureTiming::NO, {}, {}, {}, {});

    // Invalid-argument failures are surfaced through the return value, not the callback.
    const bool rejectedAsInvalid =
            !executionResult.ok() &&
            executionResult.error().code == nn::ErrorStatus::INVALID_ARGUMENT;
    if (rejectedAsInvalid) {
        const auto& [errorMessage, errorCode, unusedShapes] = executionResult.error();
        return nn::error(errorCode) << errorMessage;
    }

    notify(callback.get(), std::move(executionResult));
    return {};
}
156
// Implements the V1_2 execute path on top of a canonical prepared model.
//
// Returns an INVALID_ARGUMENT error if `callback` is null, if `request` or `measure` cannot be
// converted, or if the underlying execution itself reports INVALID_ARGUMENT. Every other
// execution outcome (success or failure) is delivered through `callback` and the function
// returns success.
nn::GeneralResult<void> execute_1_2(const nn::SharedPreparedModel& preparedModel,
                                    const V1_0::Request& request, V1_2::MeasureTiming measure,
                                    const sp<V1_2::IExecutionCallback>& callback) {
    if (callback.get() == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
    }

    const auto nnRequest = NN_TRY(convertInput(request));
    const auto nnMeasure = NN_TRY(convertInput(measure));

    auto executionResult = preparedModel->execute(nnRequest, nnMeasure, {}, {}, {}, {});

    // Invalid-argument failures are surfaced through the return value, not the callback.
    const bool rejectedAsInvalid =
            !executionResult.ok() &&
            executionResult.error().code == nn::ErrorStatus::INVALID_ARGUMENT;
    if (rejectedAsInvalid) {
        const auto& [errorMessage, errorCode, unusedShapes] = executionResult.error();
        return nn::error(errorCode) << errorMessage;
    }

    notify(callback.get(), std::move(executionResult));
    return {};
}
177
// Implements the V1_3 execute path on top of a canonical prepared model.
//
// Returns an INVALID_ARGUMENT error if `callback` is null, if any of `request`, `measure`,
// `deadline` or `loopTimeoutDuration` cannot be converted, or if the underlying execution itself
// reports INVALID_ARGUMENT. Every other execution outcome (success or failure) is delivered
// through `callback` and the function returns success.
nn::GeneralResult<void> execute_1_3(const nn::SharedPreparedModel& preparedModel,
                                    const V1_3::Request& request, V1_2::MeasureTiming measure,
                                    const V1_3::OptionalTimePoint& deadline,
                                    const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                                    const sp<V1_3::IExecutionCallback>& callback) {
    if (callback.get() == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
    }

    const auto nnRequest = NN_TRY(convertInput(request));
    const auto nnMeasure = NN_TRY(convertInput(measure));
    const auto nnDeadline = NN_TRY(convertInput(deadline));
    const auto nnLoopTimeoutDuration = NN_TRY(convertInput(loopTimeoutDuration));

    auto executionResult = preparedModel->execute(nnRequest, nnMeasure, nnDeadline,
                                                  nnLoopTimeoutDuration, {}, {});

    // Invalid-argument failures are surfaced through the return value, not the callback.
    const bool rejectedAsInvalid =
            !executionResult.ok() &&
            executionResult.error().code == nn::ErrorStatus::INVALID_ARGUMENT;
    if (rejectedAsInvalid) {
        const auto& [errorMessage, errorCode, unusedShapes] = executionResult.error();
        return nn::error(errorCode) << errorMessage;
    }

    notify(callback.get(), std::move(executionResult));
    return {};
}
203
// Implements the V1_2 synchronous execution path on top of a canonical prepared model.
//
// Converts `request` and `measure`, runs the execution, and converts the resulting output shapes
// and timing to their V1_2 HIDL representations. Any failing step short-circuits via NN_TRY and
// its error is returned to the caller.
nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> executeSynchronously(
        const nn::SharedPreparedModel& preparedModel, const V1_0::Request& request,
        V1_2::MeasureTiming measure) {
    // Inputs: HIDL -> canonical.
    const auto canonicalRequest = NN_TRY(convertInput(request));
    const auto canonicalMeasure = NN_TRY(convertInput(measure));

    const auto [shapes, measuredTiming] =
            NN_TRY(preparedModel->execute(canonicalRequest, canonicalMeasure, {}, {}, {}, {}));

    // Outputs: canonical -> HIDL.
    auto convertedShapes = NN_TRY(V1_2::utils::convert(shapes));
    const auto convertedTiming = NN_TRY(V1_2::utils::convert(measuredTiming));
    return std::make_pair(std::move(convertedShapes), convertedTiming);
}
217
// Implements the V1_3 synchronous execution path on top of a canonical prepared model.
//
// Converts `request`, `measure`, `deadline` and `loopTimeoutDuration`, runs the execution, and
// converts the resulting output shapes and timing with the V1_3 conversion utilities. Any failing
// step short-circuits via NN_TRY and its error is returned to the caller.
nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> executeSynchronously_1_3(
        const nn::SharedPreparedModel& preparedModel, const V1_3::Request& request,
        V1_2::MeasureTiming measure, const V1_3::OptionalTimePoint& deadline,
        const V1_3::OptionalTimeoutDuration& loopTimeoutDuration) {
    // Inputs: HIDL -> canonical.
    const auto canonicalRequest = NN_TRY(convertInput(request));
    const auto canonicalMeasure = NN_TRY(convertInput(measure));
    const auto canonicalDeadline = NN_TRY(convertInput(deadline));
    const auto canonicalLoopTimeout = NN_TRY(convertInput(loopTimeoutDuration));

    const auto [shapes, measuredTiming] =
            NN_TRY(preparedModel->execute(canonicalRequest, canonicalMeasure, canonicalDeadline,
                                          canonicalLoopTimeout, {}, {}));

    // Outputs: canonical -> HIDL.
    auto convertedShapes = NN_TRY(V1_3::utils::convert(shapes));
    const auto convertedTiming = NN_TRY(V1_3::utils::convert(measuredTiming));
    return std::make_pair(std::move(convertedShapes), convertedTiming);
}
234
convertSyncFences(const hidl_vec<hidl_handle> & handles)235 nn::GeneralResult<std::vector<nn::SyncFence>> convertSyncFences(
236 const hidl_vec<hidl_handle>& handles) {
237 auto nnHandles = NN_TRY(convertInput(handles));
238 std::vector<nn::SyncFence> syncFences;
239 syncFences.reserve(handles.size());
240 for (auto&& handle : nnHandles) {
241 if (auto syncFence = nn::SyncFence::create(std::move(handle)); !syncFence.ok()) {
242 return nn::error(nn::ErrorStatus::INVALID_ARGUMENT) << std::move(syncFence).error();
243 } else {
244 syncFences.push_back(std::move(syncFence).value());
245 }
246 }
247 return syncFences;
248 }
249
// Creates a burst context for `preparedModel`.
//
// Asks the prepared model for a burst executor (returning its error on failure) and hands that
// executor, together with the client's callback, the two FMQ channel descriptors, and the burst
// server polling time window, to Burst::create.
nn::GeneralResult<sp<V1_2::IBurstContext>> configureExecutionBurst(
        const nn::SharedPreparedModel& preparedModel, const sp<V1_2::IBurstCallback>& callback,
        const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
        const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel) {
    auto executor = NN_TRY(preparedModel->configureExecutionBurst());

    // Queried only once the executor has been obtained successfully.
    const auto pollingTimeWindow = V1_2::utils::getBurstServerPollingTimeWindow();
    return Burst::create(callback, requestChannel, resultChannel, std::move(executor),
                         pollingTimeWindow);
}
258
// Implements the V1_3 fenced execution path on top of a canonical prepared model.
//
// Converts every HIDL argument (the wait-for handles become nn::SyncFence objects), launches the
// fenced execution, and returns the resulting sync fence as a hidl_handle together with a
// FencedExecutionCallback wrapping the callback produced by the execution. Any failing step
// short-circuits via NN_TRY and its error is returned to the caller.
nn::GeneralResult<std::pair<hidl_handle, sp<V1_3::IFencedExecutionCallback>>> executeFenced(
        const nn::SharedPreparedModel& preparedModel, const V1_3::Request& request,
        const hidl_vec<hidl_handle>& waitFor, V1_2::MeasureTiming measure,
        const V1_3::OptionalTimePoint& deadline,
        const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
        const V1_3::OptionalTimeoutDuration& duration) {
    // Inputs: HIDL -> canonical, in the same order the arguments are validated upstream.
    const auto canonicalRequest = NN_TRY(convertInput(request));
    const auto canonicalWaitFor = NN_TRY(convertSyncFences(waitFor));
    const auto canonicalMeasure = NN_TRY(convertInput(measure));
    const auto canonicalDeadline = NN_TRY(convertInput(deadline));
    const auto canonicalLoopTimeout = NN_TRY(convertInput(loopTimeoutDuration));
    const auto canonicalDuration = NN_TRY(convertInput(duration));

    auto [fence, infoCallback] = NN_TRY(preparedModel->executeFenced(
            canonicalRequest, canonicalWaitFor, canonicalMeasure, canonicalDeadline,
            canonicalLoopTimeout, canonicalDuration, {}, {}));

    // Outputs: canonical -> HIDL.
    auto hidlFence = NN_TRY(V1_3::utils::convert(fence.getSharedHandle()));
    auto hidlInfoCallback = sp<FencedExecutionCallback>::make(infoCallback);
    return std::make_pair(std::move(hidlFence), std::move(hidlInfoCallback));
}
280
281 } // namespace
282
PreparedModel(nn::SharedPreparedModel preparedModel)283 PreparedModel::PreparedModel(nn::SharedPreparedModel preparedModel)
284 : kPreparedModel(std::move(preparedModel)) {
285 CHECK(kPreparedModel != nullptr);
286 }
287
getUnderlyingPreparedModel() const288 nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const {
289 return kPreparedModel;
290 }
291
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)292 Return<V1_0::ErrorStatus> PreparedModel::execute(const V1_0::Request& request,
293 const sp<V1_0::IExecutionCallback>& callback) {
294 auto result = adapter::execute(kPreparedModel, request, callback);
295 if (!result.has_value()) {
296 auto [message, code] = std::move(result).error();
297 LOG(ERROR) << "adapter::PreparedModel::execute failed with " << code << ": " << message;
298 notify(callback.get(), code, {}, {});
299 return V1_0::utils::convert(code).value();
300 }
301 return V1_0::ErrorStatus::NONE;
302 }
303
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)304 Return<V1_0::ErrorStatus> PreparedModel::execute_1_2(const V1_0::Request& request,
305 V1_2::MeasureTiming measure,
306 const sp<V1_2::IExecutionCallback>& callback) {
307 auto result = adapter::execute_1_2(kPreparedModel, request, measure, callback);
308 if (!result.has_value()) {
309 auto [message, code] = std::move(result).error();
310 LOG(ERROR) << "adapter::PreparedModel::execute_1_2 failed with " << code << ": " << message;
311 notify(callback.get(), code, {}, {});
312 return V1_2::utils::convert(code).value();
313 }
314 return V1_0::ErrorStatus::NONE;
315 }
316
execute_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)317 Return<V1_3::ErrorStatus> PreparedModel::execute_1_3(
318 const V1_3::Request& request, V1_2::MeasureTiming measure,
319 const V1_3::OptionalTimePoint& deadline,
320 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
321 const sp<V1_3::IExecutionCallback>& callback) {
322 auto result = adapter::execute_1_3(kPreparedModel, request, measure, deadline,
323 loopTimeoutDuration, callback);
324 if (!result.has_value()) {
325 auto [message, code] = std::move(result).error();
326 LOG(ERROR) << "adapter::PreparedModel::execute_1_3 failed with " << code << ": " << message;
327 notify(callback.get(), code, {}, {});
328 return V1_3::utils::convert(code).value();
329 }
330 return V1_3::ErrorStatus::NONE;
331 }
332
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)333 Return<void> PreparedModel::executeSynchronously(const V1_0::Request& request,
334 V1_2::MeasureTiming measure,
335 executeSynchronously_cb cb) {
336 auto result = adapter::executeSynchronously(kPreparedModel, request, measure);
337 if (!result.has_value()) {
338 auto [message, code, outputShapes] = std::move(result).error();
339 LOG(ERROR) << "adapter::PreparedModel::executeSynchronously failed with " << code << ": "
340 << message;
341 cb(V1_2::utils::convert(code).value(), V1_2::utils::convert(outputShapes).value(),
342 V1_2::utils::kNoTiming);
343 return Void();
344 }
345 auto [outputShapes, timing] = std::move(result).value();
346 cb(V1_0::ErrorStatus::NONE, outputShapes, timing);
347 return Void();
348 }
349
executeSynchronously_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)350 Return<void> PreparedModel::executeSynchronously_1_3(
351 const V1_3::Request& request, V1_2::MeasureTiming measure,
352 const V1_3::OptionalTimePoint& deadline,
353 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
354 auto result = adapter::executeSynchronously_1_3(kPreparedModel, request, measure, deadline,
355 loopTimeoutDuration);
356 if (!result.has_value()) {
357 auto [message, code, outputShapes] = std::move(result).error();
358 LOG(ERROR) << "adapter::PreparedModel::executeSynchronously_1_3 failed with " << code
359 << ": " << message;
360 cb(V1_3::utils::convert(code).value(), V1_3::utils::convert(outputShapes).value(),
361 V1_2::utils::kNoTiming);
362 return Void();
363 }
364 auto [outputShapes, timing] = std::move(result).value();
365 cb(V1_3::ErrorStatus::NONE, outputShapes, timing);
366 return Void();
367 }
368
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)369 Return<void> PreparedModel::configureExecutionBurst(
370 const sp<V1_2::IBurstCallback>& callback,
371 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
372 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
373 configureExecutionBurst_cb cb) {
374 auto result = adapter::configureExecutionBurst(kPreparedModel, callback, requestChannel,
375 resultChannel);
376 if (!result.has_value()) {
377 auto [message, code] = std::move(result).error();
378 LOG(ERROR) << "adapter::PreparedModel::configureExecutionBurst failed with " << code << ": "
379 << message;
380 cb(V1_2::utils::convert(code).value(), nullptr);
381 return Void();
382 }
383 const auto burstContext = std::move(result).value();
384 cb(V1_0::ErrorStatus::NONE, burstContext);
385 return Void();
386 }
387
executeFenced(const V1_3::Request & request,const hidl_vec<hidl_handle> & waitFor,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const V1_3::OptionalTimeoutDuration & duration,executeFenced_cb callback)388 Return<void> PreparedModel::executeFenced(const V1_3::Request& request,
389 const hidl_vec<hidl_handle>& waitFor,
390 V1_2::MeasureTiming measure,
391 const V1_3::OptionalTimePoint& deadline,
392 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
393 const V1_3::OptionalTimeoutDuration& duration,
394 executeFenced_cb callback) {
395 auto result = adapter::executeFenced(kPreparedModel, request, waitFor, measure, deadline,
396 loopTimeoutDuration, duration);
397 if (!result.has_value()) {
398 auto [message, code] = std::move(result).error();
399 LOG(ERROR) << "adapter::PreparedModel::executeFenced failed with " << code << ": "
400 << message;
401 callback(V1_3::utils::convert(code).value(), {}, nullptr);
402 return Void();
403 }
404 auto [syncFence, executeFencedCallback] = std::move(result).value();
405 callback(V1_3::ErrorStatus::NONE, syncFence, executeFencedCallback);
406 return Void();
407 }
408
409 } // namespace android::hardware::neuralnetworks::adapter
410