xref: /aosp_15_r20/external/android-nn-driver/test/Convolution2D.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li 
6*3e777be0SXin Li #pragma once
7*3e777be0SXin Li 
8*3e777be0SXin Li #include "DriverTestHelpers.hpp"
9*3e777be0SXin Li 
10*3e777be0SXin Li #include <log/log.h>
11*3e777be0SXin Li 
12*3e777be0SXin Li #include <OperationsUtils.h>
13*3e777be0SXin Li 
14*3e777be0SXin Li using namespace android::hardware;
15*3e777be0SXin Li using namespace driverTestHelpers;
16*3e777be0SXin Li using namespace armnn_driver;
17*3e777be0SXin Li 
18*3e777be0SXin Li using RequestArgument = V1_0::RequestArgument;
19*3e777be0SXin Li 
20*3e777be0SXin Li namespace driverTestHelpers
21*3e777be0SXin Li {
// Checks a computed value against the expectation for the active precision mode.
//
//   result          - the value produced by the driver
//   fp16Expectation - value accepted only when fp16Enabled is true
//   fp32Expectation - value accepted in both modes
//   fp16Enabled     - when true, either expectation passes (and a descriptive message is
//                     attached to a failure); when false only fp32Expectation passes
//
// Every value parameter is parenthesised so that an argument containing a low-precedence
// operator (e.g. `a ? b : c`) is compared as a whole. `result` is expanded more than once,
// so do not pass an expression with side effects.
//
// NOTE: this deliberately stays a bare if/else rather than a do { } while (0) wrapper,
// because call sites in this file invoke it without a trailing semicolon. Do not use it
// as the unbraced body of another if/else.
#define ARMNN_ANDROID_FP16_TEST(result, fp16Expectation, fp32Expectation, fp16Enabled) \
   if (fp16Enabled) \
   { \
       DOCTEST_CHECK_MESSAGE(((result) == (fp16Expectation) || (result) == (fp32Expectation)), (result) << \
       " does not match either " << (fp16Expectation) << "[fp16] or " << (fp32Expectation) << "[fp32]"); \
   } else \
   { \
      DOCTEST_CHECK((result) == (fp32Expectation)); \
   }
31*3e777be0SXin Li 
32*3e777be0SXin Li void SetModelFp16Flag(V1_0::Model& model, bool fp16Enabled);
33*3e777be0SXin Li 
34*3e777be0SXin Li void SetModelFp16Flag(V1_1::Model& model, bool fp16Enabled);
35*3e777be0SXin Li 
36*3e777be0SXin Li template<typename HalPolicy>
PaddingTestImpl(android::nn::PaddingScheme paddingScheme,bool fp16Enabled=false)37*3e777be0SXin Li void PaddingTestImpl(android::nn::PaddingScheme paddingScheme, bool fp16Enabled = false)
38*3e777be0SXin Li {
39*3e777be0SXin Li     using HalModel         = typename HalPolicy::Model;
40*3e777be0SXin Li     using HalOperationType = typename HalPolicy::OperationType;
41*3e777be0SXin Li 
42*3e777be0SXin Li     armnn::Compute computeDevice = armnn::Compute::GpuAcc;
43*3e777be0SXin Li 
44*3e777be0SXin Li #ifndef ARMCOMPUTECL_ENABLED
45*3e777be0SXin Li     computeDevice = armnn::Compute::CpuRef;
46*3e777be0SXin Li #endif
47*3e777be0SXin Li 
48*3e777be0SXin Li     auto driver = std::make_unique<ArmnnDriver>(DriverOptions(computeDevice, fp16Enabled));
49*3e777be0SXin Li     HalModel model = {};
50*3e777be0SXin Li 
51*3e777be0SXin Li     uint32_t outSize = paddingScheme == android::nn::kPaddingSame ? 2 : 1;
52*3e777be0SXin Li 
53*3e777be0SXin Li     // add operands
54*3e777be0SXin Li     float weightValue[] = {1.f, -1.f, 0.f, 1.f};
55*3e777be0SXin Li     float biasValue[] = {0.f};
56*3e777be0SXin Li 
57*3e777be0SXin Li     AddInputOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 2, 3, 1});
58*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 2, 2, 1}, weightValue);
59*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec < uint32_t > {1}, biasValue);
60*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, (int32_t) paddingScheme); // padding
61*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, 2); // stride x
62*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, 2); // stride y
63*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, 0); // no activation
64*3e777be0SXin Li     AddOutputOperand<HalPolicy>(model, hidl_vec < uint32_t > {1, 1, outSize, 1});
65*3e777be0SXin Li 
66*3e777be0SXin Li     // make the convolution operation
67*3e777be0SXin Li     model.operations.resize(1);
68*3e777be0SXin Li     model.operations[0].type = HalOperationType::CONV_2D;
69*3e777be0SXin Li     model.operations[0].inputs = hidl_vec < uint32_t > {0, 1, 2, 3, 4, 5, 6};
70*3e777be0SXin Li     model.operations[0].outputs = hidl_vec < uint32_t > {7};
71*3e777be0SXin Li 
72*3e777be0SXin Li     // make the prepared model
73*3e777be0SXin Li     SetModelFp16Flag(model, fp16Enabled);
74*3e777be0SXin Li     android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
75*3e777be0SXin Li 
76*3e777be0SXin Li     // construct the request
77*3e777be0SXin Li     V1_0::DataLocation inloc = {};
78*3e777be0SXin Li     inloc.poolIndex = 0;
79*3e777be0SXin Li     inloc.offset = 0;
80*3e777be0SXin Li     inloc.length = 6 * sizeof(float);
81*3e777be0SXin Li     RequestArgument input = {};
82*3e777be0SXin Li     input.location = inloc;
83*3e777be0SXin Li     input.dimensions = hidl_vec < uint32_t > {};
84*3e777be0SXin Li 
85*3e777be0SXin Li     V1_0::DataLocation outloc = {};
86*3e777be0SXin Li     outloc.poolIndex = 1;
87*3e777be0SXin Li     outloc.offset = 0;
88*3e777be0SXin Li     outloc.length = outSize * sizeof(float);
89*3e777be0SXin Li     RequestArgument output = {};
90*3e777be0SXin Li     output.location = outloc;
91*3e777be0SXin Li     output.dimensions = hidl_vec < uint32_t > {};
92*3e777be0SXin Li 
93*3e777be0SXin Li     V1_0::Request request = {};
94*3e777be0SXin Li     request.inputs = hidl_vec < RequestArgument > {input};
95*3e777be0SXin Li     request.outputs = hidl_vec < RequestArgument > {output};
96*3e777be0SXin Li 
97*3e777be0SXin Li     // set the input data (matching source test)
98*3e777be0SXin Li     float indata[] = {1024.25f, 1.f, 0.f, 3.f, -1, -1024.25f};
99*3e777be0SXin Li     AddPoolAndSetData(6, request, indata);
100*3e777be0SXin Li 
101*3e777be0SXin Li     // add memory for the output
102*3e777be0SXin Li     android::sp<IMemory> outMemory = AddPoolAndGetData<float>(outSize, request);
103*3e777be0SXin Li     float* outdata = reinterpret_cast<float*>(static_cast<void*>(outMemory->getPointer()));
104*3e777be0SXin Li 
105*3e777be0SXin Li     // run the execution
106*3e777be0SXin Li     if (preparedModel.get() != nullptr)
107*3e777be0SXin Li     {
108*3e777be0SXin Li         Execute(preparedModel, request);
109*3e777be0SXin Li     }
110*3e777be0SXin Li 
111*3e777be0SXin Li     // check the result
112*3e777be0SXin Li     switch (paddingScheme)
113*3e777be0SXin Li     {
114*3e777be0SXin Li         case android::nn::kPaddingValid:
115*3e777be0SXin Li             ARMNN_ANDROID_FP16_TEST(outdata[0], 1022.f, 1022.25f, fp16Enabled)
116*3e777be0SXin Li             break;
117*3e777be0SXin Li         case android::nn::kPaddingSame:
118*3e777be0SXin Li             ARMNN_ANDROID_FP16_TEST(outdata[0], 1022.f, 1022.25f, fp16Enabled)
119*3e777be0SXin Li             DOCTEST_CHECK(outdata[1] == 0.f);
120*3e777be0SXin Li             break;
121*3e777be0SXin Li         default:
122*3e777be0SXin Li             DOCTEST_CHECK(false);
123*3e777be0SXin Li             break;
124*3e777be0SXin Li     }
125*3e777be0SXin Li }
126*3e777be0SXin Li 
127*3e777be0SXin Li } // namespace driverTestHelpers
128