xref: /aosp_15_r20/external/android-nn-driver/test/TestHalfTensor.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li 
6*3e777be0SXin Li #pragma once
7*3e777be0SXin Li 
#include <ArmnnDriver.hpp>
#include "DriverTestHelpers.hpp"

#include <half/half.hpp>

#include <utility>
#include <vector>
12*3e777be0SXin Li 
13*3e777be0SXin Li using Half = half_float::half;
14*3e777be0SXin Li 
15*3e777be0SXin Li namespace driverTestHelpers
16*3e777be0SXin Li {
17*3e777be0SXin Li 
18*3e777be0SXin Li class TestHalfTensor
19*3e777be0SXin Li {
20*3e777be0SXin Li public:
TestHalfTensor(const armnn::TensorShape & shape,const std::vector<Half> & data)21*3e777be0SXin Li     TestHalfTensor(const armnn::TensorShape & shape,
22*3e777be0SXin Li                const std::vector<Half> & data)
23*3e777be0SXin Li         : m_Shape{shape}
24*3e777be0SXin Li         , m_Data{data}
25*3e777be0SXin Li     {
26*3e777be0SXin Li         DOCTEST_CHECK(m_Shape.GetNumElements() == m_Data.size());
27*3e777be0SXin Li     }
28*3e777be0SXin Li 
29*3e777be0SXin Li     hidl_vec<uint32_t> GetDimensions() const;
30*3e777be0SXin Li     unsigned int GetNumElements() const;
31*3e777be0SXin Li     const Half * GetData() const;
32*3e777be0SXin Li 
33*3e777be0SXin Li private:
34*3e777be0SXin Li     armnn::TensorShape   m_Shape;
35*3e777be0SXin Li     std::vector<Half>   m_Data;
36*3e777be0SXin Li };
37*3e777be0SXin Li 
38*3e777be0SXin Li } // driverTestHelpers
39