1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ 18 19 #include <stddef.h> 20 #include <stdint.h> 21 22 #include <array> 23 #include <functional> 24 #include <numeric> 25 #include <string> 26 #include <utility> 27 #include <vector> 28 29 namespace tflite { 30 namespace gpu { 31 32 enum class Axis { 33 UNKNOWN = 0, 34 CHANNELS = 1, 35 INPUT_CHANNELS = 2, 36 OUTPUT_CHANNELS = 3, 37 HEIGHT = 4, 38 WIDTH = 5, 39 BATCH = 6, 40 VALUE = 7, 41 DEPTH = 8, 42 }; 43 44 std::string ToString(Axis axis); 45 46 // Layout represents axis order. 47 enum class Layout { 48 UNKNOWN = 0, 49 SCALAR = 1, 50 LINEAR = 2, 51 HW = 3, 52 CHW = 4, 53 HWC = 5, 54 OIHW = 6, 55 OHWI = 7, 56 IHWO = 8, 57 IOHW = 9, 58 BHWC = 10, 59 HWDC = 11, 60 BHWDC = 12, 61 HWD = 13, 62 OHWDI = 14, 63 HWIO = 15, 64 }; 65 66 std::string ToString(Layout l); 67 68 // Returns number of axis for the fixed layout. 69 template <Layout T> 70 constexpr int Size(); 71 72 // Returns number of axis for the given layout. 73 int Size(Layout layout); 74 75 // Returns Axis for the given index and fixed layout. 76 template <Layout T> 77 constexpr Axis GetAxis(int index); 78 79 // Returns axis for the given layout and index. 80 Axis GetAxis(Layout layout, int32_t index); 81 82 // Returns axis index for the given axis and fixed layout. 83 template <Layout T> 84 constexpr int GetAxisIndex(Axis axis); 85 86 // Returns axis index for the given layout and axis. 87 int GetAxisIndex(Layout layout, Axis axis); 88 89 // Checks if fixed layout has given axis 90 template <Layout T> 91 constexpr bool HasAxis(Axis axis); 92 93 // Checks if given layout has given axis 94 bool HasAxis(Layout layout, Axis axis); 95 96 // Stores Layout(axis set and order) and value for dimensions. 97 struct Shape { ShapeShape98 Shape() : layout(Layout::UNKNOWN), dimensions() {} 99 ShapeShape100 explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {} 101 ShapeShape102 Shape(Layout t, std::vector<int32_t> d) 103 : layout(t), dimensions(std::move(d)) {} 104 105 bool operator==(const Shape& other) const { 106 return (layout == other.layout) && (dimensions == other.dimensions); 107 } 108 109 bool operator!=(const Shape& other) const { return !operator==(other); } 110 111 // All methods below are matching same methods defined in StrongShape to 112 // make sure generic algorithms work both ways. 113 114 // Returns back a dimension or -1 if it is not found. 115 template <Axis D> 116 int32_t get() const; 117 int32_t get(Axis axis) const; 118 119 template <Axis D> 120 bool set(int32_t t); 121 bool set(Axis axis, int32_t t); 122 axisShape123 Axis axis(int index) const { return GetAxis(layout, index); } 124 indexShape125 int index(Axis axis) const { return GetAxisIndex(layout, axis); } 126 hasShape127 bool has(Axis axis) const { return HasAxis(layout, axis); } 128 DimensionsProductShape129 int64_t DimensionsProduct() const { 130 return std::accumulate(dimensions.begin(), dimensions.end(), 1LL, 131 std::multiplies<int64_t>()); 132 } 133 134 Layout layout = Layout::UNKNOWN; 135 136 std::vector<int32_t> dimensions; 137 }; 138 139 std::string ToString(const Shape& s); 140 141 // StrongShape provides convenient explicit access to dimensions stored in 142 // shape, e.g. StrongShape<Layout::HW> s; provides s.h and s.w accessors. 143 // 144 // There is a conversion possible both ways between Shape and StrongShape. 145 // 146 // OIHW oihw; // specific shape 147 // Shape l = oihw.ToShape(); 148 // 149 // OHWI other; // notice not the same but compatible shape. 150 // if (!other.Adopt(l)) { 151 // // error handling 152 // } 153 // 154 // StrongShape supports the following set of operations: 155 // 156 // // Returns number of axis in the shape class. 157 // static constexpr int size(); 158 // 159 // // Returns Axis for the given index or Axis::UNKNOWN if index 160 // // falls outside of the defined range in this shape. 161 // static constexpr Axis axis(int index); 162 // 163 // // Returns index for the given axis or -1 if axis is not defined in this 164 // // shape. 165 // static constexpr int index(Axis axis); 166 // 167 // // Getters 168 // int32_t get(int index) const; 169 // int32_t get(Axis axis) const; 170 // int32_t get<Axis>() const; 171 // 172 // // Setters that return false if set was not successful. 173 // bool set(int index, int32_t v); 174 // bool set(Axis axis, int32_t v); 175 // bool set<Axis>(int32_t v); 176 // 177 // // Returns shape's layout. 178 // static const Layout layout; 179 // 180 // // Turns specific shape into generic shape. 181 // Shape ToShape() const; 182 // 183 // // Copies all dimensions from the given shape. 184 // bool Adopt(const Shape&); 185 // 186 template <Layout L> 187 struct StrongShape; 188 189 using Scalar = StrongShape<Layout::SCALAR>; 190 using Linear = StrongShape<Layout::LINEAR>; 191 using HW = StrongShape<Layout::HW>; 192 using HWD = StrongShape<Layout::HWD>; 193 194 // Common tensor shape for CNN models working with images. 195 using CHW = StrongShape<Layout::CHW>; 196 using HWC = StrongShape<Layout::HWC>; 197 using HWDC = StrongShape<Layout::HWDC>; 198 using BHWC = StrongShape<Layout::BHWC>; 199 using BHWDC = StrongShape<Layout::BHWDC>; 200 201 // Tensor shape used in convolution_2d weights. 202 using OIHW = StrongShape<Layout::OIHW>; 203 using OHWI = StrongShape<Layout::OHWI>; 204 using IHWO = StrongShape<Layout::IHWO>; 205 using IOHW = StrongShape<Layout::IOHW>; 206 using HWIO = StrongShape<Layout::HWIO>; 207 208 // Tensor shape used in convolution_3d weights. 209 using OHWDI = StrongShape<Layout::OHWDI>; 210 211 // ----------------------------------------------------------------------------- 212 // Everything below are internal implementation details. 213 // ----------------------------------------------------------------------------- 214 215 namespace internal_shape { 216 217 template <Axis T> 218 struct AxisTraits; 219 220 #define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName) \ 221 template <> \ 222 struct AxisTraits<Axis::AxisName> { \ 223 struct Holder { \ 224 int32_t HolderName; \ 225 \ 226 protected: \ 227 int32_t operator()() const { return HolderName; } \ 228 void operator()(int32_t v) { HolderName = v; } \ 229 }; \ 230 \ 231 using dimension_holder_type = Holder; \ 232 } 233 234 TFLITE_GPU_AXIS_TRAITS(CHANNELS, c); 235 TFLITE_GPU_AXIS_TRAITS(HEIGHT, h); 236 TFLITE_GPU_AXIS_TRAITS(WIDTH, w); 237 TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i); 238 TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o); 239 TFLITE_GPU_AXIS_TRAITS(BATCH, b); 240 TFLITE_GPU_AXIS_TRAITS(VALUE, v); 241 TFLITE_GPU_AXIS_TRAITS(DEPTH, d); 242 243 #undef TFLITE_GPU_AXIS_TRAITS 244 245 template <int N, Axis... As> 246 struct StrongShapeImpl; 247 248 template <int N> 249 struct StrongShapeImpl<N> { 250 static constexpr int size() { return N; } 251 252 static constexpr Axis axis(int) { return Axis::UNKNOWN; } 253 254 static constexpr int index(Axis) { return -1; } 255 256 static constexpr bool has(Axis) { return false; } 257 258 int32_t get(Axis) const { return -1; } 259 260 int32_t get(int) const { return -1; } 261 262 template <Axis B> 263 int32_t get() const { 264 return -1; 265 } 266 267 bool set(Axis, int32_t) { return false; } 268 269 bool set(int, int32_t) { return false; } 270 271 template <Axis B> 272 bool set(int32_t) { 273 return false; 274 } 275 }; 276 277 // Used to deduce number of axis, and to be a child of a proper holder to 278 // provide access to the dimension by name 279 template <int N, Axis A, Axis... As> 280 struct StrongShapeImpl<N, A, As...> 281 : public AxisTraits<A>::dimension_holder_type, 282 public StrongShapeImpl<N + 1, As...> { 283 using dimension_holder_type = typename AxisTraits<A>::dimension_holder_type; 284 285 using rest_type = StrongShapeImpl<N + 1, As...>; 286 287 StrongShapeImpl() : dimension_holder_type{0}, rest_type() {} 288 289 template <typename... Ts> 290 explicit StrongShapeImpl(int32_t t, Ts... ts) 291 : dimension_holder_type{t}, rest_type(ts...) {} 292 293 static constexpr Axis axis(int index) { 294 return index == N ? A : rest_type::axis(index); 295 } 296 297 static constexpr int index(Axis axis) { 298 return axis == A ? N : rest_type::index(axis); 299 } 300 301 static constexpr bool has(Axis axis) { 302 return axis == A ? true : rest_type::has(axis); 303 } 304 305 int32_t get(Axis axis) const { 306 return axis == A ? dimension_holder_type::operator()() 307 : rest_type::get(axis); 308 } 309 310 template <Axis B> 311 int32_t get() const { 312 return B == A ? dimension_holder_type::operator()() 313 : rest_type::template get<B>(); 314 } 315 316 int32_t get(int index) const { 317 return index == N ? dimension_holder_type::operator()() 318 : rest_type::get(index); 319 } 320 321 bool set(Axis axis, int32_t t) { 322 if (axis == A) { 323 dimension_holder_type::operator()(t); 324 return true; 325 } 326 return rest_type::set(axis, t); 327 } 328 329 bool set(int index, int32_t t) { 330 if (index == N) { 331 dimension_holder_type::operator()(t); 332 return true; 333 } 334 return rest_type::set(index, t); 335 } 336 337 template <Axis B> 338 bool set(int32_t t) { 339 if (A == B) { 340 dimension_holder_type::operator()(t); 341 return true; 342 } 343 return rest_type::template set<B>(t); 344 } 345 }; 346 347 template <Layout T> 348 struct LayoutTraits; 349 350 #define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...) \ 351 template <> \ 352 struct LayoutTraits<Layout::LayoutName> { \ 353 using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \ 354 } 355 356 TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH); 357 TFLITE_GPU_LAYOUT_TRAITS(HWD, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH); 358 TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH, 359 Axis::INPUT_CHANNELS); 360 TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS, 361 Axis::HEIGHT, Axis::WIDTH); 362 TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS, 363 Axis::HEIGHT, Axis::WIDTH); 364 TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH, 365 Axis::OUTPUT_CHANNELS); 366 TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH); 367 TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS); 368 TFLITE_GPU_LAYOUT_TRAITS(HWDC, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH, 369 Axis::CHANNELS); 370 TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE); 371 TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE); 372 TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, 373 Axis::CHANNELS); 374 TFLITE_GPU_LAYOUT_TRAITS(BHWDC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, 375 Axis::DEPTH, Axis::CHANNELS); 376 TFLITE_GPU_LAYOUT_TRAITS(OHWDI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, 377 Axis::WIDTH, Axis::DEPTH, Axis::INPUT_CHANNELS); 378 TFLITE_GPU_LAYOUT_TRAITS(HWIO, Axis::HEIGHT, Axis::WIDTH, Axis::INPUT_CHANNELS, 379 Axis::OUTPUT_CHANNELS); 380 381 #undef TFLITE_GPU_LAYOUT_TRAITS 382 383 template <> 384 struct LayoutTraits<Layout::UNKNOWN> { 385 using strong_shape_type = StrongShapeImpl<0>; 386 }; 387 388 template <Axis A> 389 struct DimensionGetterFixedAxisFunc { 390 template <Layout T> 391 int32_t operator()() const { 392 constexpr int i = GetAxisIndex<T>(A); 393 return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1; 394 } 395 const Shape* l; 396 }; 397 398 struct DimensionGetterFunc { 399 template <Layout T> 400 int32_t operator()() const { 401 int i = GetAxisIndex<T>(axis); 402 return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1; 403 } 404 Axis axis; 405 const Shape* l; 406 }; 407 408 template <Axis A> 409 struct DimensionSetterFixedAxisFunc { 410 template <Layout T> 411 bool operator()() const { 412 constexpr int i = GetAxisIndex<T>(A); 413 if (i >= 0 && i < l->dimensions.size()) { 414 l->dimensions[i] = v; 415 return true; 416 } 417 return false; 418 } 419 Shape* l; 420 int32_t v; 421 }; 422 423 struct DimensionSetterFunc { 424 template <Layout T> 425 bool operator()() const { 426 int i = GetAxisIndex<T>(axis); 427 if (i >= 0 && i < l->dimensions.size()) { 428 l->dimensions[i] = v; 429 return true; 430 } 431 return false; 432 } 433 Axis axis; 434 Shape* l; 435 int32_t v; 436 }; 437 438 template <Layout L> 439 struct ToShapeFunc { 440 template <Layout T> 441 bool operator()() const { 442 for (int i = 0; i < StrongShape<L>::size(); ++i) { 443 int index = GetAxisIndex<T>(StrongShape<L>::axis(i)); 444 if (index < 0) return false; 445 shape->set(i, l.dimensions[index]); 446 } 447 return true; 448 } 449 450 StrongShape<L>* shape; 451 const Shape& l; 452 }; 453 454 } // namespace internal_shape 455 456 // template <Axis... As> 457 template <Layout L> 458 struct StrongShape : public internal_shape::LayoutTraits<L>::strong_shape_type { 459 using strong_shape_type = 460 typename internal_shape::LayoutTraits<L>::strong_shape_type; 461 StrongShape() = default; 462 463 template <typename... Ts> 464 explicit StrongShape(Ts... t) : strong_shape_type(t...) {} 465 466 constexpr static Layout layout = L; 467 468 bool operator==(const StrongShape<L>& shape) const { 469 // TODO(akulik): implement better alternative. 470 return this->ToShape() == shape.ToShape(); 471 } 472 473 bool operator!=(const StrongShape<L>& shape) const { 474 // TODO(akulik): implement better alternative. 475 return this->ToShape() != shape.ToShape(); 476 } 477 bool empty() const { return DimensionsProduct() == 0; } 478 479 // Turns StrongShape into generic shape. 480 Shape ToShape() const { 481 std::vector<int32_t> dimensions(StrongShape::size()); 482 for (int i = 0; i < StrongShape::size(); ++i) { 483 dimensions[i] = StrongShape::get(i); 484 } 485 return Shape(L, std::move(dimensions)); 486 } 487 488 // @return all dimensions multiplied 489 int64_t DimensionsProduct() const { 490 int64_t product = 1; 491 for (int i = 0; i < StrongShape::size(); ++i) { 492 product *= StrongShape::get(i); 493 } 494 return product; 495 } 496 497 // Translates given coordinates of the layout into a linear index assuming 498 // dimensions are sorted in tensor access order e.g. if you access 499 // foobar[i][j][k] order of coordinates should be i,j,k. 500 int64_t LinearIndex( 501 const std::array<int32_t, StrongShape::size()>& coordinates) const { 502 int64_t index = coordinates[0]; 503 for (int i = 1; i < StrongShape::size(); ++i) { 504 index = index * StrongShape::get(i) + coordinates[i]; 505 } 506 return index; 507 } 508 509 // Copies all dimensions from the given generic shape into specific shape. 510 // It requires shape to have all axis defined in the given 511 // StrongShape. For example: 512 // - If this shape is OHWI but given shape is OIHW, Adopt will copy all 513 // dimensions and return true. 514 // - If this shape is OIHW but input shape is HW, Adopt will copy H and W 515 // dimensions and return true, but if this shape is HW and given shape 516 // OIHW, then Adopt will return false because not all axis are present in 517 // the input shape. 518 // 519 // @return false if generic shape is not compatible. 520 bool Adopt(const Shape& shape) { 521 return DispatchByLayout(shape.layout, 522 internal_shape::ToShapeFunc<L>{this, shape}); 523 } 524 525 // For all axis defined in a given shape copies values to this shape. 526 // Therefore, it is possible to copy dimensions from CHW to BCHW, but not 527 // the other way around. 528 // 529 // BCHW bchw; 530 // CHW chw; 531 // bchw.CopyAllGivenAxis(chw); --> true 532 // chw.CopyAllGivenAxis(bchw); --> false 533 // 534 // @return false if axis in source shape is not defined here, thus value 535 // was not copied. 536 template <Layout B> 537 bool CopyAllGivenAxis(const StrongShape<B>& source) { 538 for (int i = 0; i < source.size(); ++i) { 539 if (!StrongShape::set(source.axis(i), source.get(i))) { 540 return false; 541 } 542 } 543 return true; 544 } 545 546 // For all axis defined in this shape copies values from the given shape. 547 // 548 // BCHW bchw; 549 // CHW chw; 550 // bchw.CopyAllDefinedAxis(chw); --> false 551 // chw.CopyAllDefinedAxis(bchw); --> true 552 // 553 // @return false if given shape does not have axis defined here, 554 // therefore a value was not copied. 555 template <Layout B> 556 bool CopyAllDefinedAxis(const StrongShape<B>& source) { 557 for (int i = 0; i < StrongShape::size(); ++i) { 558 int source_index = source.index(StrongShape::axis(i)); 559 if (source_index < 0) { 560 return false; 561 } 562 StrongShape::set(i, source.get(source_index)); // always true 563 } 564 return true; 565 } 566 567 // Copies values only for matching axis. 568 template <Layout B> 569 void CopyMatchingAxis(const StrongShape<B>& source) { 570 for (int i = 0; i < StrongShape::size(); ++i) { 571 StrongShape::set(source.axis(i), source.get(i)); 572 } 573 } 574 575 // AbslHash function for using in flat hash containers. 576 template <typename H> 577 friend H AbslHashValue(H hash_state, const StrongShape& strong_shape) { 578 for (size_t i = 0; i < strong_shape.size(); ++i) { 579 hash_state = H::combine(std::move(hash_state), strong_shape.get(i)); 580 } 581 return hash_state; 582 } 583 }; 584 585 template <Layout T> 586 inline std::string ToString(const StrongShape<T>& s) { 587 return ToString(s.ToShape()); 588 } 589 590 template <Layout L> 591 constexpr Layout StrongShape<L>::layout; 592 593 template <class F> 594 auto DispatchByLayout(Layout type, F f) 595 -> decltype(f.template operator()<Layout::UNKNOWN>()) { 596 switch (type) { 597 case Layout::HW: 598 return f.template operator()<Layout::HW>(); 599 case Layout::HWD: 600 return f.template operator()<Layout::HWD>(); 601 case Layout::HWC: 602 return f.template operator()<Layout::HWC>(); 603 case Layout::HWDC: 604 return f.template operator()<Layout::HWDC>(); 605 case Layout::CHW: 606 return f.template operator()<Layout::CHW>(); 607 case Layout::OIHW: 608 return f.template operator()<Layout::OIHW>(); 609 case Layout::IOHW: 610 return f.template operator()<Layout::IOHW>(); 611 case Layout::OHWI: 612 return f.template operator()<Layout::OHWI>(); 613 case Layout::IHWO: 614 return f.template operator()<Layout::IHWO>(); 615 case Layout::LINEAR: 616 return f.template operator()<Layout::LINEAR>(); 617 case Layout::SCALAR: 618 return f.template operator()<Layout::SCALAR>(); 619 case Layout::BHWC: 620 return f.template operator()<Layout::BHWC>(); 621 case Layout::BHWDC: 622 return f.template operator()<Layout::BHWDC>(); 623 case Layout::OHWDI: 624 return f.template operator()<Layout::OHWDI>(); 625 case Layout::HWIO: 626 return f.template operator()<Layout::HWIO>(); 627 case Layout::UNKNOWN: 628 return f.template operator()<Layout::UNKNOWN>(); 629 } 630 } 631 632 template <Layout T> 633 constexpr int Size() { 634 return StrongShape<T>::size(); 635 } 636 637 template <Layout T> 638 constexpr Axis GetAxis(int index) { 639 return StrongShape<T>::axis(index); 640 } 641 642 template <Layout T> 643 constexpr int GetAxisIndex(Axis axis) { 644 return StrongShape<T>::index(axis); 645 } 646 647 template <Layout T> 648 constexpr bool HasAxis(Axis axis) { 649 return StrongShape<T>::has(axis); 650 } 651 652 template <Axis D> 653 inline int32_t Shape::get() const { 654 return DispatchByLayout( 655 layout, internal_shape::DimensionGetterFixedAxisFunc<D>{this}); 656 } 657 658 inline int32_t Shape::get(Axis axis) const { 659 return DispatchByLayout(layout, 660 internal_shape::DimensionGetterFunc{axis, this}); 661 } 662 663 template <Axis D> 664 inline bool Shape::set(int32_t t) { 665 return DispatchByLayout( 666 layout, internal_shape::DimensionSetterFixedAxisFunc<D>{this, t}); 667 } 668 669 inline bool Shape::set(Axis axis, int32_t t) { 670 return DispatchByLayout(layout, 671 internal_shape::DimensionSetterFunc{axis, this, t}); 672 } 673 674 template <Layout T> 675 std::ostream& operator<<(std::ostream& ostream, const StrongShape<T>& shape) { 676 ostream << ToString(shape); 677 return ostream; 678 } 679 680 } // namespace gpu 681 } // namespace tflite 682 683 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ 684