xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/assembly/ndrange.hpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2019-2021 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #pragma once
25*c217d954SCole Faust 
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <limits>
#include <utility>
30*c217d954SCole Faust 
31*c217d954SCole Faust namespace arm_gemm
32*c217d954SCole Faust {
/** A D-dimensional index space with extents m_sizes[0] .. m_sizes[D - 1].
 *
 * Linear positions are laid out with dimension 0 varying fastest: the
 * coordinate along dimension d of linear position p is
 * (p % (s0 * ... * sd)) / (s0 * ... * s(d-1)).
 * Any extent given as 0 is stored as 1, so every dimension has at least one
 * element and total_size() is at least 1.
 */
template <unsigned int D>
class NDRange
{
    // D - 1 is used as an index below; D == 0 would wrap around.
    static_assert(D > 0, "NDRange requires at least one dimension");

private:
    // Extent of each dimension (zero extents are normalised to 1).
    std::array<unsigned int, D> m_sizes{};
    // Cumulative products: m_totalsizes[i] == m_sizes[0] * ... * m_sizes[i].
    std::array<unsigned int, D> m_totalsizes{};

    /** Cursor over the linear positions [start, end) of a parent NDRange.
     *
     * Only a reference to the parent is stored, so the NDRange that created
     * the iterator must outlive it.
     */
    class NDRangeIterator
    {
    private:
        const NDRange &m_parent;
        unsigned int   m_pos = 0;
        unsigned int   m_end = 0;

    public:
        /** @param p Range being walked.
         *  @param s First linear position.
         *  @param e One past the last linear position.
         */
        NDRangeIterator(const NDRange &p, unsigned int s, unsigned int e)
            : m_parent(p), m_pos(s), m_end(e)
        {
        }

        /** @return true once the current position has reached or passed the end. */
        bool done() const
        {
            return (m_pos >= m_end);
        }

        /** @return Coordinate of the current position along dimension @p d (requires d < D). */
        unsigned int dim(unsigned int d) const
        {
            assert(d < D);

            unsigned int r = m_pos;

            // Strip the contribution of the dimensions above d. The outermost
            // dimension has none above it, so the modulo is skipped there.
            if(d < (D - 1))
            {
                r %= m_parent.m_totalsizes[d];
            }

            // Divide out the dimensions below d.
            if(d > 0)
            {
                r /= m_parent.m_totalsizes[d - 1];
            }

            return r;
        }

        /** Advance by one linear position.
         *  @return true if the new position is still before the end.
         */
        bool next_dim0()
        {
            m_pos++;

            return !done();
        }

        /** Skip to the first position (dim(0) == 0) of the next dimension-0 row.
         *  @return true if the new position is still before the end.
         */
        bool next_dim1()
        {
            m_pos += m_parent.m_sizes[0] - dim(0);

            return !done();
        }

        /** @return Exclusive upper bound for dim(0) within the current row,
         *  clipped so it does not run past the iterator's end. If the position
         *  has already moved past the end, m_end - m_pos wraps and only the
         *  row length limits the result.
         */
        unsigned int dim0_max() const
        {
            unsigned int offset = std::min(m_end - m_pos, m_parent.m_sizes[0] - dim(0));

            return dim(0) + offset;
        }
    };

    /** Normalise zero extents to 1 and fill in the cumulative products. */
    void set_totalsizes()
    {
        unsigned int t = 1;

        for(unsigned int i = 0; i < D; i++)
        {
            if(m_sizes[i] == 0)
            {
                m_sizes[i] = 1;
            }

            // A wrapped product would silently corrupt every dim() result.
            assert(t <= std::numeric_limits<unsigned int>::max() / m_sizes[i]);

            t *= m_sizes[i];

            m_totalsizes[i] = t;
        }
    }

public:
    NDRange &operator=(const NDRange &rhs) = default;
    NDRange(const NDRange &rhs)            = default;

    /** Construct from up to D extents, dimension 0 first.
     *  Missing trailing extents and zero extents are stored as 1.
     */
    template <typename... T>
    NDRange(T... ts)
        : m_sizes{ ts... }
    {
        static_assert(sizeof...(T) <= D, "NDRange given more extents than it has dimensions");
        set_totalsizes();
    }

    /** Construct from an array of extents (zero extents are stored as 1). */
    NDRange(const std::array<unsigned int, D> &n)
        : m_sizes(n)
    {
        set_totalsizes();
    }

    /** @return An iterator over linear positions [start, end); *this must outlive it. */
    NDRangeIterator iterator(unsigned int start, unsigned int end) const
    {
        return NDRangeIterator(*this, start, end);
    }

    /** @return Number of linear positions (product of all extents). */
    unsigned int total_size() const
    {
        return m_totalsizes[D - 1];
    }

    /** @return Extent of dimension @p v (requires v < D). */
    unsigned int get_size(unsigned int v) const
    {
        assert(v < D);

        return m_sizes[v];
    }
};
146*c217d954SCole Faust 
147*c217d954SCole Faust /** NDCoordinate builds upon a range, but specifies a starting position
148*c217d954SCole Faust  * in addition to a size which it inherits from NDRange
149*c217d954SCole Faust  */
150*c217d954SCole Faust template <unsigned int N>
151*c217d954SCole Faust class NDCoordinate : public NDRange<N>
152*c217d954SCole Faust {
153*c217d954SCole Faust     using int_t     = unsigned int;
154*c217d954SCole Faust     using ndrange_t = NDRange<N>;
155*c217d954SCole Faust 
156*c217d954SCole Faust     std::array<int_t, N> m_positions{};
157*c217d954SCole Faust 
158*c217d954SCole Faust public:
159*c217d954SCole Faust     NDCoordinate &operator=(const NDCoordinate &rhs) = default;
160*c217d954SCole Faust     NDCoordinate(const NDCoordinate &rhs)            = default;
NDCoordinate(const std::initializer_list<std::pair<int_t,int_t>> & list)161*c217d954SCole Faust     NDCoordinate(const std::initializer_list<std::pair<int_t, int_t>> &list)
162*c217d954SCole Faust     {
163*c217d954SCole Faust         std::array<int_t, N> sizes{};
164*c217d954SCole Faust 
165*c217d954SCole Faust         std::size_t i = 0;
166*c217d954SCole Faust         for(auto &p : list)
167*c217d954SCole Faust         {
168*c217d954SCole Faust             m_positions[i] = p.first;
169*c217d954SCole Faust             sizes[i++]     = p.second;
170*c217d954SCole Faust         }
171*c217d954SCole Faust 
172*c217d954SCole Faust         //update the parents sizes
173*c217d954SCole Faust         static_cast<ndrange_t &>(*this) = ndrange_t(sizes);
174*c217d954SCole Faust     }
175*c217d954SCole Faust 
get_position(int_t d) const176*c217d954SCole Faust     int_t get_position(int_t d) const
177*c217d954SCole Faust     {
178*c217d954SCole Faust         assert(d < N);
179*c217d954SCole Faust 
180*c217d954SCole Faust         return m_positions[d];
181*c217d954SCole Faust     }
182*c217d954SCole Faust 
set_position(int_t d,int_t v)183*c217d954SCole Faust     void set_position(int_t d, int_t v)
184*c217d954SCole Faust     {
185*c217d954SCole Faust         assert(d < N);
186*c217d954SCole Faust 
187*c217d954SCole Faust         m_positions[d] = v;
188*c217d954SCole Faust     }
189*c217d954SCole Faust 
get_position_end(int_t d) const190*c217d954SCole Faust     int_t get_position_end(int_t d) const
191*c217d954SCole Faust     {
192*c217d954SCole Faust         return get_position(d) + ndrange_t::get_size(d);
193*c217d954SCole Faust     }
194*c217d954SCole Faust }; //class NDCoordinate
195*c217d954SCole Faust 
196*c217d954SCole Faust using ndrange_t = NDRange<6>;
197*c217d954SCole Faust using ndcoord_t = NDCoordinate<6>;
198*c217d954SCole Faust 
199*c217d954SCole Faust } // namespace arm_gemm
200