xref: /aosp_15_r20/external/android-nn-driver/test/TestTensor.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li 
6*3e777be0SXin Li #pragma once
7*3e777be0SXin Li 
#include <ArmnnDriver.hpp>
#include "DriverTestHelpers.hpp"

#include <utility>
#include <vector>
10*3e777be0SXin Li 
11*3e777be0SXin Li namespace driverTestHelpers
12*3e777be0SXin Li {
13*3e777be0SXin Li 
14*3e777be0SXin Li class TestTensor
15*3e777be0SXin Li {
16*3e777be0SXin Li public:
TestTensor(const armnn::TensorShape & shape,const std::vector<float> & data)17*3e777be0SXin Li     TestTensor(const armnn::TensorShape & shape,
18*3e777be0SXin Li                const std::vector<float> & data)
19*3e777be0SXin Li     : m_Shape{shape}
20*3e777be0SXin Li     , m_Data{data}
21*3e777be0SXin Li     {
22*3e777be0SXin Li         DOCTEST_CHECK(m_Shape.GetNumElements() == m_Data.size());
23*3e777be0SXin Li     }
24*3e777be0SXin Li 
25*3e777be0SXin Li     hidl_vec<uint32_t> GetDimensions() const;
26*3e777be0SXin Li     unsigned int GetNumElements() const;
27*3e777be0SXin Li     const float * GetData() const;
28*3e777be0SXin Li 
29*3e777be0SXin Li private:
30*3e777be0SXin Li     armnn::TensorShape   m_Shape;
31*3e777be0SXin Li     std::vector<float>   m_Data;
32*3e777be0SXin Li };
33*3e777be0SXin Li 
34*3e777be0SXin Li } // driverTestHelpers
35