1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li
6*3e777be0SXin Li #pragma once
7*3e777be0SXin Li
8*3e777be0SXin Li #include "DriverTestHelpers.hpp"
9*3e777be0SXin Li
10*3e777be0SXin Li #include <log/log.h>
11*3e777be0SXin Li
12*3e777be0SXin Li #include <OperationsUtils.h>
13*3e777be0SXin Li
14*3e777be0SXin Li using namespace android::hardware;
15*3e777be0SXin Li using namespace driverTestHelpers;
16*3e777be0SXin Li using namespace armnn_driver;
17*3e777be0SXin Li
18*3e777be0SXin Li using RequestArgument = V1_0::RequestArgument;
19*3e777be0SXin Li
20*3e777be0SXin Li namespace driverTestHelpers
21*3e777be0SXin Li {
22*3e777be0SXin Li #define ARMNN_ANDROID_FP16_TEST(result, fp16Expectation, fp32Expectation, fp16Enabled) \
23*3e777be0SXin Li if (fp16Enabled) \
24*3e777be0SXin Li { \
25*3e777be0SXin Li DOCTEST_CHECK_MESSAGE((result == fp16Expectation || result == fp32Expectation), result << \
26*3e777be0SXin Li " does not match either " << fp16Expectation << "[fp16] or " << fp32Expectation << "[fp32]"); \
27*3e777be0SXin Li } else \
28*3e777be0SXin Li { \
29*3e777be0SXin Li DOCTEST_CHECK(result == fp32Expectation); \
30*3e777be0SXin Li }
31*3e777be0SXin Li
32*3e777be0SXin Li void SetModelFp16Flag(V1_0::Model& model, bool fp16Enabled);
33*3e777be0SXin Li
34*3e777be0SXin Li void SetModelFp16Flag(V1_1::Model& model, bool fp16Enabled);
35*3e777be0SXin Li
36*3e777be0SXin Li template<typename HalPolicy>
PaddingTestImpl(android::nn::PaddingScheme paddingScheme,bool fp16Enabled=false)37*3e777be0SXin Li void PaddingTestImpl(android::nn::PaddingScheme paddingScheme, bool fp16Enabled = false)
38*3e777be0SXin Li {
39*3e777be0SXin Li using HalModel = typename HalPolicy::Model;
40*3e777be0SXin Li using HalOperationType = typename HalPolicy::OperationType;
41*3e777be0SXin Li
42*3e777be0SXin Li armnn::Compute computeDevice = armnn::Compute::GpuAcc;
43*3e777be0SXin Li
44*3e777be0SXin Li #ifndef ARMCOMPUTECL_ENABLED
45*3e777be0SXin Li computeDevice = armnn::Compute::CpuRef;
46*3e777be0SXin Li #endif
47*3e777be0SXin Li
48*3e777be0SXin Li auto driver = std::make_unique<ArmnnDriver>(DriverOptions(computeDevice, fp16Enabled));
49*3e777be0SXin Li HalModel model = {};
50*3e777be0SXin Li
51*3e777be0SXin Li uint32_t outSize = paddingScheme == android::nn::kPaddingSame ? 2 : 1;
52*3e777be0SXin Li
53*3e777be0SXin Li // add operands
54*3e777be0SXin Li float weightValue[] = {1.f, -1.f, 0.f, 1.f};
55*3e777be0SXin Li float biasValue[] = {0.f};
56*3e777be0SXin Li
57*3e777be0SXin Li AddInputOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 2, 3, 1});
58*3e777be0SXin Li AddTensorOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 2, 2, 1}, weightValue);
59*3e777be0SXin Li AddTensorOperand<HalPolicy>(model, hidl_vec < uint32_t > {1}, biasValue);
60*3e777be0SXin Li AddIntOperand<HalPolicy>(model, (int32_t) paddingScheme); // padding
61*3e777be0SXin Li AddIntOperand<HalPolicy>(model, 2); // stride x
62*3e777be0SXin Li AddIntOperand<HalPolicy>(model, 2); // stride y
63*3e777be0SXin Li AddIntOperand<HalPolicy>(model, 0); // no activation
64*3e777be0SXin Li AddOutputOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 1, outSize, 1});
65*3e777be0SXin Li
66*3e777be0SXin Li // make the convolution operation
67*3e777be0SXin Li model.operations.resize(1);
68*3e777be0SXin Li model.operations[0].type = HalOperationType::CONV_2D;
69*3e777be0SXin Li model.operations[0].inputs = hidl_vec < uint32_t > {0, 1, 2, 3, 4, 5, 6};
70*3e777be0SXin Li model.operations[0].outputs = hidl_vec < uint32_t > {7};
71*3e777be0SXin Li
72*3e777be0SXin Li // make the prepared model
73*3e777be0SXin Li SetModelFp16Flag(model, fp16Enabled);
74*3e777be0SXin Li android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
75*3e777be0SXin Li
76*3e777be0SXin Li // construct the request
77*3e777be0SXin Li V1_0::DataLocation inloc = {};
78*3e777be0SXin Li inloc.poolIndex = 0;
79*3e777be0SXin Li inloc.offset = 0;
80*3e777be0SXin Li inloc.length = 6 * sizeof(float);
81*3e777be0SXin Li RequestArgument input = {};
82*3e777be0SXin Li input.location = inloc;
83*3e777be0SXin Li input.dimensions = hidl_vec < uint32_t > {};
84*3e777be0SXin Li
85*3e777be0SXin Li V1_0::DataLocation outloc = {};
86*3e777be0SXin Li outloc.poolIndex = 1;
87*3e777be0SXin Li outloc.offset = 0;
88*3e777be0SXin Li outloc.length = outSize * sizeof(float);
89*3e777be0SXin Li RequestArgument output = {};
90*3e777be0SXin Li output.location = outloc;
91*3e777be0SXin Li output.dimensions = hidl_vec < uint32_t > {};
92*3e777be0SXin Li
93*3e777be0SXin Li V1_0::Request request = {};
94*3e777be0SXin Li request.inputs = hidl_vec < RequestArgument > {input};
95*3e777be0SXin Li request.outputs = hidl_vec < RequestArgument > {output};
96*3e777be0SXin Li
97*3e777be0SXin Li // set the input data (matching source test)
98*3e777be0SXin Li float indata[] = {1024.25f, 1.f, 0.f, 3.f, -1, -1024.25f};
99*3e777be0SXin Li AddPoolAndSetData(6, request, indata);
100*3e777be0SXin Li
101*3e777be0SXin Li // add memory for the output
102*3e777be0SXin Li android::sp<IMemory> outMemory = AddPoolAndGetData<float>(outSize, request);
103*3e777be0SXin Li float* outdata = reinterpret_cast<float*>(static_cast<void*>(outMemory->getPointer()));
104*3e777be0SXin Li
105*3e777be0SXin Li // run the execution
106*3e777be0SXin Li if (preparedModel.get() != nullptr)
107*3e777be0SXin Li {
108*3e777be0SXin Li Execute(preparedModel, request);
109*3e777be0SXin Li }
110*3e777be0SXin Li
111*3e777be0SXin Li // check the result
112*3e777be0SXin Li switch (paddingScheme)
113*3e777be0SXin Li {
114*3e777be0SXin Li case android::nn::kPaddingValid:
115*3e777be0SXin Li ARMNN_ANDROID_FP16_TEST(outdata[0], 1022.f, 1022.25f, fp16Enabled)
116*3e777be0SXin Li break;
117*3e777be0SXin Li case android::nn::kPaddingSame:
118*3e777be0SXin Li ARMNN_ANDROID_FP16_TEST(outdata[0], 1022.f, 1022.25f, fp16Enabled)
119*3e777be0SXin Li DOCTEST_CHECK(outdata[1] == 0.f);
120*3e777be0SXin Li break;
121*3e777be0SXin Li default:
122*3e777be0SXin Li DOCTEST_CHECK(false);
123*3e777be0SXin Li break;
124*3e777be0SXin Li }
125*3e777be0SXin Li }
126*3e777be0SXin Li
127*3e777be0SXin Li } // namespace driverTestHelpers
128