/*
 * Copyright (c) 2019-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <utility>

namespace arm_gemm
{
/** A D-dimensional iteration space.
 *
 * Every dimension has a size; a size of 0 is stored as 1. Points in the space
 * are addressed by a single linear position in which dimension 0 varies
 * fastest: m_totalsizes[i] holds the product of the sizes of dimensions 0..i,
 * so m_totalsizes[D - 1] is the total number of points.
 */
template <unsigned int D>
class NDRange
{
private:
    std::array<unsigned int, D> m_sizes{};
    std::array<unsigned int, D> m_totalsizes{};

    /** Walks the half-open window [start, end) of linear positions of a parent NDRange.
     *
     * Stores a reference to the parent range, so it must not outlive it.
     */
    class NDRangeIterator
    {
    private:
        const NDRange &m_parent;
        unsigned int   m_pos = 0;
        unsigned int   m_end = 0;

    public:
        NDRangeIterator(const NDRange &p, unsigned int s, unsigned int e)
            : m_parent(p), m_pos(s), m_end(e)
        {
        }

        /** @return true once the current position has reached (or passed) the end of the window. */
        bool done() const
        {
            return (m_pos >= m_end);
        }

        /** Coordinate of the current position along dimension @p d.
         *
         * The last dimension is not wrapped by its size, so it keeps counting
         * for positions at or beyond total_size().
         *
         * @param d Dimension index; must be less than D.
         */
        unsigned int dim(unsigned int d) const
        {
            assert(d < D);

            unsigned int r = m_pos;

            if(d < (D - 1))
            {
                r %= m_parent.m_totalsizes[d];
            }

            if(d > 0)
            {
                r /= m_parent.m_totalsizes[d - 1];
            }

            return r;
        }

        /** Advance by one linear position.
         *
         * @return true if the window is not yet exhausted.
         */
        bool next_dim0()
        {
            m_pos++;

            return !done();
        }

        /** Skip the remainder of the current dimension-0 row, landing on the first
         *  position (dim(0) == 0) of the next row.
         *
         * @return true if the window is not yet exhausted.
         */
        bool next_dim1()
        {
            m_pos += m_parent.m_sizes[0] - dim(0);

            return !done();
        }

        /** Exclusive upper bound for dimension 0 in the current row, clipped so
         *  that it does not run past the end of the window.
         *
         * NOTE(review): intended for use while !done(); if the position has
         * passed the end, the unsigned subtraction wraps and the row size wins.
         */
        unsigned int dim0_max() const
        {
            unsigned int offset = std::min(m_end - m_pos, m_parent.m_sizes[0] - dim(0));

            return dim(0) + offset;
        }
    };

    /** Replace zero sizes with 1 and fill m_totalsizes with the running products of the sizes. */
    void set_totalsizes()
    {
        unsigned int t = 1;

        for(unsigned int i = 0; i < D; i++)
        {
            // A zero-sized dimension is treated as having a single element.
            if(m_sizes[i] == 0)
            {
                m_sizes[i] = 1;
            }

            t *= m_sizes[i];

            m_totalsizes[i] = t;
        }
    }

public:
    NDRange &operator=(const NDRange &rhs) = default;
    NDRange(const NDRange &rhs)            = default;

    /** Construct from up to D sizes, dimension 0 first.
     *
     * Sizes that are not supplied (including all of them, when called with no
     * arguments) are zero-initialised and therefore become 1.
     */
    template <typename... T>
    NDRange(T... ts)
        : m_sizes{ ts... }
    {
        set_totalsizes();
    }

    /** Construct from an explicit array of sizes; zero entries become 1. */
    NDRange(const std::array<unsigned int, D> &n)
        : m_sizes(n)
    {
        set_totalsizes();
    }

    /** Iterator over the linear positions [start, end) of this range. */
    NDRangeIterator iterator(unsigned int start, unsigned int end) const
    {
        return NDRangeIterator(*this, start, end);
    }

    /** Number of points in the range (product of all dimension sizes). */
    unsigned int total_size() const
    {
        return m_totalsizes[D - 1];
    }

    /** Size of dimension @p v; @p v must be less than D. */
    unsigned int get_size(unsigned int v) const
    {
        assert(v < D);

        return m_sizes[v];
    }
};

/** NDCoordinate extends NDRange with a per-dimension starting position;
 * the per-dimension sizes are the ones it inherits from NDRange.
 */
template <unsigned int N>
class NDCoordinate : public NDRange<N>
{
    using int_t     = unsigned int;
    using ndrange_t = NDRange<N>;

    std::array<int_t, N> m_positions{};

public:
    NDCoordinate &operator=(const NDCoordinate &rhs) = default;
    NDCoordinate(const NDCoordinate &rhs)            = default;

    /** Construct from (position, size) pairs, one per dimension starting at dimension 0.
     *
     * Dimensions not covered by @p list get position 0 and size 1. The list
     * must contain at most N entries.
     */
    NDCoordinate(const std::initializer_list<std::pair<int_t, int_t>> &list)
    {
        // More than N entries would index past the fixed-size arrays below.
        assert(list.size() <= N);

        std::array<int_t, N> sizes{};

        std::size_t i = 0;
        for(const auto &p : list)
        {
            if(i >= N)
            {
                // Keep release builds (asserts disabled) from writing out of bounds.
                break;
            }

            m_positions[i] = p.first;
            sizes[i]       = p.second;
            ++i;
        }

        // Update the parent's sizes (zero sizes become 1 there).
        static_cast<ndrange_t &>(*this) = ndrange_t(sizes);
    }

    /** Starting position along dimension @p d; @p d must be less than N. */
    int_t get_position(int_t d) const
    {
        assert(d < N);

        return m_positions[d];
    }

    /** Set the starting position along dimension @p d to @p v; @p d must be less than N. */
    void set_position(int_t d, int_t v)
    {
        assert(d < N);

        m_positions[d] = v;
    }

    /** One past the last position along dimension @p d (start + size). */
    int_t get_position_end(int_t d) const
    {
        return get_position(d) + ndrange_t::get_size(d);
    }
}; //class NDCoordinate

using ndrange_t = NDRange<6>;
using ndcoord_t = NDCoordinate<6>;

} // namespace arm_gemm