1 /* 2 * Copyright (c) 2017-2021 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #pragma once 25 26 #include "../asmlib.hpp" 27 28 template <unsigned int IntBy, typename TIn, typename TOut> 29 struct TransposeInterleaveCommon { 30 // Override the moveblock_1xY methods to improve performance moveblock_1x1TransposeInterleaveCommon31 static inline void moveblock_1x1(const TIn *&in0, TOut *out) { 32 for (unsigned int i = 0; i < IntBy; i++) { 33 *out++ = static_cast<TOut>(*in0++); 34 } 35 } 36 moveblock_1x2TransposeInterleaveCommon37 static inline void moveblock_1x2(const TIn *&in0, const TIn *&in1, TOut *out) { 38 for (unsigned int i = 0; i < IntBy; i++) { 39 *out++ = static_cast<TOut>(*in0++); 40 } 41 for (unsigned int i = 0; i < IntBy; i++) { 42 *out++ = static_cast<TOut>(*in1++); 43 } 44 } 45 moveblock_1x4TransposeInterleaveCommon46 static inline void moveblock_1x4(const TIn *&in0, const TIn *&in1, const TIn *&in2, const TIn *&in3, TOut *out) { 47 for (unsigned int i = 0; i < IntBy; i++) { 48 *out++ = static_cast<TOut>(*in0++); 49 } 50 for (unsigned int i = 0; i < IntBy; i++) { 51 *out++ = static_cast<TOut>(*in1++); 52 } 53 for (unsigned int i = 0; i < IntBy; i++) { 54 *out++ = static_cast<TOut>(*in2++); 55 } 56 for (unsigned int i = 0; i < IntBy; i++) { 57 *out++ = static_cast<TOut>(*in3++); 58 } 59 } 60 TransformTransposeInterleaveCommon61 static void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax) { 62 const auto ldin = stride; 63 64 TOut *outarray = out; 65 const TIn *inarray = in; 66 TOut *outptr_base = outarray; 67 const TIn *inptr_base = inarray + x0 + (k0 * ldin); 68 int ldout = (kmax - k0) * IntBy; 69 70 int k=(kmax-k0); 71 for ( ; k>3; k-=4) { 72 TOut *outptr = outptr_base; 73 const TIn *inptr = inptr_base; 74 const TIn *inptr1 = inptr + ldin; 75 const TIn *inptr2 = inptr1 + ldin; 76 const TIn *inptr3 = inptr2 + ldin; 77 78 prefetch_3x(inptr); 79 prefetch_3x(inptr1); 80 prefetch_3x(inptr2); 81 prefetch_3x(inptr3); 82 83 outptr_base += IntBy * 4; 84 inptr_base += ldin * 4; 85 86 for (int x = (xmax-x0) / IntBy; x > 0 ; x--) { 87 moveblock_1x4(inptr, inptr1, inptr2, inptr3, outptr); 88 outptr += ldout; 89 } 90 } 91 92 if (k) { 93 TOut *outptr = outptr_base; 94 const TIn *inptr = inptr_base; 95 const TIn *inptr1 = inptr + ldin; 96 const TIn *inptr2 = inptr1 + ldin; 97 98 prefetch_3x(inptr); 99 prefetch_3x(inptr1); 100 prefetch_3x(inptr2); 101 102 for (int x = (xmax-x0) / IntBy; x > 0 ; x--) { 103 switch(k) { 104 case 3: 105 moveblock_1x2(inptr, inptr1, outptr); 106 moveblock_1x1(inptr2, outptr + IntBy * 2); 107 break; 108 109 case 2: 110 moveblock_1x2(inptr, inptr1, outptr); 111 break; 112 113 case 1: 114 moveblock_1x1(inptr, outptr); 115 break; 116 117 default: 118 UNREACHABLE("Impossible."); 119 } 120 121 outptr += ldout; 122 } 123 } 124 125 // Cope with ragged X cases 126 const unsigned int overflow = (xmax - x0) % IntBy; 127 if (overflow) { 128 const TIn *inptr_base = inarray + (xmax - overflow) + (k0 * ldin); 129 TOut *outptr = outarray + ((xmax - x0) / IntBy) * ldout; 130 131 for (int k=(kmax-k0); k>0; k--) { 132 const TIn *inptr = inptr_base; 133 inptr_base += ldin; 134 135 for (unsigned int x=0; x < IntBy; x++) { 136 TOut val = (x < overflow) ? static_cast<TOut>(*inptr++) : static_cast<TOut>(0); 137 *outptr++ = val; 138 } 139 } 140 } 141 } 142 }; 143