1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
#include "CommonTestUtils.hpp"

#include <ResolveType.hpp>

#include <armnn/INetwork.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <doctest/doctest.h>

#include <map>
#include <stdexcept>
#include <utility>
#include <vector>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker namespace
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnTypeInput>
CreateElementwiseBinaryNetwork(const TensorShape & input1Shape,const TensorShape & input2Shape,const TensorShape & outputShape,BinaryOperation operation,const float qScale=1.0f,const int32_t qOffset=0)22*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateElementwiseBinaryNetwork(const TensorShape& input1Shape,
23*89c4ff92SAndroid Build Coastguard Worker const TensorShape& input2Shape,
24*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape,
25*89c4ff92SAndroid Build Coastguard Worker BinaryOperation operation,
26*89c4ff92SAndroid Build Coastguard Worker const float qScale = 1.0f,
27*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = 0)
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create());
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker TensorInfo input1TensorInfo(input1Shape, ArmnnTypeInput, qScale, qOffset, true);
34*89c4ff92SAndroid Build Coastguard Worker TensorInfo input2TensorInfo(input2Shape, ArmnnTypeInput, qScale, qOffset, true);
35*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShape, ArmnnTypeInput, qScale, qOffset);
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(armnn::numeric_cast<LayerBindingId>(0));
38*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(armnn::numeric_cast<LayerBindingId>(1));
39*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* elementwiseBinaryLayer = net->AddElementwiseBinaryLayer(operation, "elementwiseUnary");
40*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output");
41*89c4ff92SAndroid Build Coastguard Worker
42*89c4ff92SAndroid Build Coastguard Worker Connect(input1, elementwiseBinaryLayer, input1TensorInfo, 0, 0);
43*89c4ff92SAndroid Build Coastguard Worker Connect(input2, elementwiseBinaryLayer, input2TensorInfo, 0, 1);
44*89c4ff92SAndroid Build Coastguard Worker Connect(elementwiseBinaryLayer, output, outputTensorInfo, 0, 0);
45*89c4ff92SAndroid Build Coastguard Worker
46*89c4ff92SAndroid Build Coastguard Worker return net;
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnInType,
50*89c4ff92SAndroid Build Coastguard Worker typename TInput = armnn::ResolveType<ArmnnInType>>
ElementwiseBinarySimpleEndToEnd(const std::vector<BackendId> & backends,BinaryOperation operation)51*89c4ff92SAndroid Build Coastguard Worker void ElementwiseBinarySimpleEndToEnd(const std::vector<BackendId>& backends,
52*89c4ff92SAndroid Build Coastguard Worker BinaryOperation operation)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker const float qScale = IsQuantizedType<TInput>() ? 0.25f : 1.0f;
57*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = IsQuantizedType<TInput>() ? 50 : 0;
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker const TensorShape& input1Shape = { 2, 2, 2, 2 };
60*89c4ff92SAndroid Build Coastguard Worker const TensorShape& input2Shape = { 1 };
61*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = { 2, 2, 2, 2 };
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
64*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateElementwiseBinaryNetwork<ArmnnInType>(input1Shape, input2Shape, outputShape,
65*89c4ff92SAndroid Build Coastguard Worker operation, qScale, qOffset);
66*89c4ff92SAndroid Build Coastguard Worker
67*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
68*89c4ff92SAndroid Build Coastguard Worker
69*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> input1({ 1, -1, 1, 1, 5, -5, 5, 5, -3, 3, 3, 3, 4, 4, -4, 4 });
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> input2({ 2 });
72*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput;
73*89c4ff92SAndroid Build Coastguard Worker switch (operation) {
74*89c4ff92SAndroid Build Coastguard Worker case armnn::BinaryOperation::Add:
75*89c4ff92SAndroid Build Coastguard Worker expectedOutput = { 3, 1, 3, 3, 7, -3, 7, 7, -1, 5, 5, 5, 6, 6, -2, 6 };
76*89c4ff92SAndroid Build Coastguard Worker break;
77*89c4ff92SAndroid Build Coastguard Worker case armnn::BinaryOperation::Div:
78*89c4ff92SAndroid Build Coastguard Worker expectedOutput = {0.5f, -0.5f, 0.5f, 0.5f, 2.5f, -2.5f, 2.5f, 2.5f, -1.5f, 1.5f, 1.5f, 1.5f, 2, 2, -2, 2};
79*89c4ff92SAndroid Build Coastguard Worker break;
80*89c4ff92SAndroid Build Coastguard Worker case armnn::BinaryOperation::Maximum:
81*89c4ff92SAndroid Build Coastguard Worker expectedOutput = { 2, 2, 2, 2, 5, 2, 5, 5, 2, 3, 3, 3, 4, 4, 2, 4 };
82*89c4ff92SAndroid Build Coastguard Worker break;
83*89c4ff92SAndroid Build Coastguard Worker case armnn::BinaryOperation::Minimum:
84*89c4ff92SAndroid Build Coastguard Worker expectedOutput = { 1, -1, 1, 1, 2, -5, 2, 2, -3, 2, 2, 2, 2, 2, -4, 2 };
85*89c4ff92SAndroid Build Coastguard Worker break;
86*89c4ff92SAndroid Build Coastguard Worker case armnn::BinaryOperation::Mul:
87*89c4ff92SAndroid Build Coastguard Worker expectedOutput = { 2, -2, 2, 2, 10, -10, 10, 10, -6, 6, 6, 6, 8, 8, -8, 8 };
88*89c4ff92SAndroid Build Coastguard Worker break;
89*89c4ff92SAndroid Build Coastguard Worker case armnn::BinaryOperation::Sub:
90*89c4ff92SAndroid Build Coastguard Worker expectedOutput = { -1, -3, -1, -1, 3, -7, 3, 3, -5, 1, 1, 1, 2, 2, -6, 2 };
91*89c4ff92SAndroid Build Coastguard Worker break;
92*89c4ff92SAndroid Build Coastguard Worker default:
93*89c4ff92SAndroid Build Coastguard Worker throw("Invalid Elementwise Binary operation");
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> expectedOutput_const = expectedOutput;
96*89c4ff92SAndroid Build Coastguard Worker // quantize data
97*89c4ff92SAndroid Build Coastguard Worker std::vector<TInput> qInput1Data = armnnUtils::QuantizedVector<TInput>(input1, qScale, qOffset);
98*89c4ff92SAndroid Build Coastguard Worker std::vector<TInput> qInput2Data = armnnUtils::QuantizedVector<TInput>(input2, qScale, qOffset);
99*89c4ff92SAndroid Build Coastguard Worker std::vector<TInput> qExpectedOutput = armnnUtils::QuantizedVector<TInput>(expectedOutput_const, qScale, qOffset);
100*89c4ff92SAndroid Build Coastguard Worker
101*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<TInput>> inputTensorData = {{ 0, qInput1Data }, { 1, qInput2Data }};
102*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<TInput>> expectedOutputData = {{ 0, qExpectedOutput }};
103*89c4ff92SAndroid Build Coastguard Worker
104*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnInType, ArmnnInType>(std::move(net), inputTensorData, expectedOutputData, backends);
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker
107*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
108