xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/shape.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
18 
#include <stddef.h>
#include <stdint.h>

#include <array>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
28 
29 namespace tflite {
30 namespace gpu {
31 
// Named tensor axes. Every enumerator carries an explicit numeric value.
enum class Axis : int {
  UNKNOWN = 0,
  CHANNELS = 1,
  INPUT_CHANNELS = 2,
  OUTPUT_CHANNELS = 3,
  HEIGHT = 4,
  WIDTH = 5,
  BATCH = 6,
  VALUE = 7,
  DEPTH = 8,
};

// Returns a textual representation of `axis` (defined out of line).
std::string ToString(Axis axis);
45 
46 // Layout represents axis order.
47 enum class Layout {
48   UNKNOWN = 0,
49   SCALAR = 1,
50   LINEAR = 2,
51   HW = 3,
52   CHW = 4,
53   HWC = 5,
54   OIHW = 6,
55   OHWI = 7,
56   IHWO = 8,
57   IOHW = 9,
58   BHWC = 10,
59   HWDC = 11,
60   BHWDC = 12,
61   HWD = 13,
62   OHWDI = 14,
63   HWIO = 15,
64 };
65 
66 std::string ToString(Layout l);
67 
68 // Returns number of axis for the fixed layout.
69 template <Layout T>
70 constexpr int Size();
71 
72 // Returns number of axis for the given layout.
73 int Size(Layout layout);
74 
75 // Returns Axis for the given index and fixed layout.
76 template <Layout T>
77 constexpr Axis GetAxis(int index);
78 
79 // Returns axis for the given layout and index.
80 Axis GetAxis(Layout layout, int32_t index);
81 
82 // Returns axis index for the given axis and fixed layout.
83 template <Layout T>
84 constexpr int GetAxisIndex(Axis axis);
85 
86 // Returns axis index for the given layout and axis.
87 int GetAxisIndex(Layout layout, Axis axis);
88 
89 // Checks if fixed layout has given axis
90 template <Layout T>
91 constexpr bool HasAxis(Axis axis);
92 
93 // Checks if given layout has given axis
94 bool HasAxis(Layout layout, Axis axis);
95 
96 // Stores Layout(axis set and order) and value for dimensions.
97 struct Shape {
ShapeShape98   Shape() : layout(Layout::UNKNOWN), dimensions() {}
99 
ShapeShape100   explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {}
101 
ShapeShape102   Shape(Layout t, std::vector<int32_t> d)
103       : layout(t), dimensions(std::move(d)) {}
104 
105   bool operator==(const Shape& other) const {
106     return (layout == other.layout) && (dimensions == other.dimensions);
107   }
108 
109   bool operator!=(const Shape& other) const { return !operator==(other); }
110 
111   // All methods below are matching same methods defined in StrongShape to
112   // make sure generic algorithms work both ways.
113 
114   // Returns back a dimension or -1 if it is not found.
115   template <Axis D>
116   int32_t get() const;
117   int32_t get(Axis axis) const;
118 
119   template <Axis D>
120   bool set(int32_t t);
121   bool set(Axis axis, int32_t t);
122 
axisShape123   Axis axis(int index) const { return GetAxis(layout, index); }
124 
indexShape125   int index(Axis axis) const { return GetAxisIndex(layout, axis); }
126 
hasShape127   bool has(Axis axis) const { return HasAxis(layout, axis); }
128 
DimensionsProductShape129   int64_t DimensionsProduct() const {
130     return std::accumulate(dimensions.begin(), dimensions.end(), 1LL,
131                            std::multiplies<int64_t>());
132   }
133 
134   Layout layout = Layout::UNKNOWN;
135 
136   std::vector<int32_t> dimensions;
137 };
138 
139 std::string ToString(const Shape& s);
140 
141 // StrongShape provides convenient explicit access to dimensions stored in
142 // shape, e.g. StrongShape<Layout::HW> s; provides s.h and s.w accessors.
143 //
144 // There is a conversion possible both ways between Shape and StrongShape.
145 //
146 //   OIHW oihw;  // specific shape
147 //   Shape l = oihw.ToShape();
148 //
149 //   OHWI other;  // notice not the same but compatible shape.
150 //   if (!other.Adopt(l)) {
151 //     // error handling
152 //   }
153 //
154 // StrongShape supports the following set of operations:
155 //
156 //   // Returns number of axis in the shape class.
157 //   static constexpr int size();
158 //
159 //   // Returns Axis for the given index or Axis::UNKNOWN if index
160 //   // falls outside of the defined range in this shape.
161 //   static constexpr Axis axis(int index);
162 //
163 //   // Returns index for the given axis or -1 if axis is not defined in this
164 //   // shape.
165 //   static constexpr int index(Axis axis);
166 //
167 //   // Getters
168 //   int32_t get(int index) const;
169 //   int32_t get(Axis axis) const;
170 //   int32_t get<Axis>() const;
171 //
172 //   // Setters that return false if set was not successful.
173 //   bool set(int index, int32_t v);
174 //   bool set(Axis axis, int32_t v);
175 //   bool set<Axis>(int32_t v);
176 //
177 //   // Returns shape's layout.
178 //   static const Layout layout;
179 //
180 //   // Turns specific shape into generic shape.
181 //   Shape ToShape() const;
182 //
183 //   // Copies all dimensions from the given shape.
184 //   bool Adopt(const Shape&);
185 //
186 template <Layout L>
187 struct StrongShape;
188 
189 using Scalar = StrongShape<Layout::SCALAR>;
190 using Linear = StrongShape<Layout::LINEAR>;
191 using HW = StrongShape<Layout::HW>;
192 using HWD = StrongShape<Layout::HWD>;
193 
194 // Common tensor shape for CNN models working with images.
195 using CHW = StrongShape<Layout::CHW>;
196 using HWC = StrongShape<Layout::HWC>;
197 using HWDC = StrongShape<Layout::HWDC>;
198 using BHWC = StrongShape<Layout::BHWC>;
199 using BHWDC = StrongShape<Layout::BHWDC>;
200 
201 // Tensor shape used in convolution_2d weights.
202 using OIHW = StrongShape<Layout::OIHW>;
203 using OHWI = StrongShape<Layout::OHWI>;
204 using IHWO = StrongShape<Layout::IHWO>;
205 using IOHW = StrongShape<Layout::IOHW>;
206 using HWIO = StrongShape<Layout::HWIO>;
207 
208 // Tensor shape used in convolution_3d weights.
209 using OHWDI = StrongShape<Layout::OHWDI>;
210 
211 // -----------------------------------------------------------------------------
// Everything below is an internal implementation detail.
213 // -----------------------------------------------------------------------------
214 
215 namespace internal_shape {
216 
217 template <Axis T>
218 struct AxisTraits;
219 
220 #define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName)    \
221   template <>                                           \
222   struct AxisTraits<Axis::AxisName> {                   \
223     struct Holder {                                     \
224       int32_t HolderName;                               \
225                                                         \
226      protected:                                         \
227       int32_t operator()() const { return HolderName; } \
228       void operator()(int32_t v) { HolderName = v; }    \
229     };                                                  \
230                                                         \
231     using dimension_holder_type = Holder;               \
232   }
233 
234 TFLITE_GPU_AXIS_TRAITS(CHANNELS, c);
235 TFLITE_GPU_AXIS_TRAITS(HEIGHT, h);
236 TFLITE_GPU_AXIS_TRAITS(WIDTH, w);
237 TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i);
238 TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o);
239 TFLITE_GPU_AXIS_TRAITS(BATCH, b);
240 TFLITE_GPU_AXIS_TRAITS(VALUE, v);
241 TFLITE_GPU_AXIS_TRAITS(DEPTH, d);
242 
243 #undef TFLITE_GPU_AXIS_TRAITS
244 
245 template <int N, Axis... As>
246 struct StrongShapeImpl;
247 
248 template <int N>
249 struct StrongShapeImpl<N> {
250   static constexpr int size() { return N; }
251 
252   static constexpr Axis axis(int) { return Axis::UNKNOWN; }
253 
254   static constexpr int index(Axis) { return -1; }
255 
256   static constexpr bool has(Axis) { return false; }
257 
258   int32_t get(Axis) const { return -1; }
259 
260   int32_t get(int) const { return -1; }
261 
262   template <Axis B>
263   int32_t get() const {
264     return -1;
265   }
266 
267   bool set(Axis, int32_t) { return false; }
268 
269   bool set(int, int32_t) { return false; }
270 
271   template <Axis B>
272   bool set(int32_t) {
273     return false;
274   }
275 };
276 
277 // Used to deduce number of axis, and to be a child of a proper holder to
278 // provide access to the dimension by name
279 template <int N, Axis A, Axis... As>
280 struct StrongShapeImpl<N, A, As...>
281     : public AxisTraits<A>::dimension_holder_type,
282       public StrongShapeImpl<N + 1, As...> {
283   using dimension_holder_type = typename AxisTraits<A>::dimension_holder_type;
284 
285   using rest_type = StrongShapeImpl<N + 1, As...>;
286 
287   StrongShapeImpl() : dimension_holder_type{0}, rest_type() {}
288 
289   template <typename... Ts>
290   explicit StrongShapeImpl(int32_t t, Ts... ts)
291       : dimension_holder_type{t}, rest_type(ts...) {}
292 
293   static constexpr Axis axis(int index) {
294     return index == N ? A : rest_type::axis(index);
295   }
296 
297   static constexpr int index(Axis axis) {
298     return axis == A ? N : rest_type::index(axis);
299   }
300 
301   static constexpr bool has(Axis axis) {
302     return axis == A ? true : rest_type::has(axis);
303   }
304 
305   int32_t get(Axis axis) const {
306     return axis == A ? dimension_holder_type::operator()()
307                      : rest_type::get(axis);
308   }
309 
310   template <Axis B>
311   int32_t get() const {
312     return B == A ? dimension_holder_type::operator()()
313                   : rest_type::template get<B>();
314   }
315 
316   int32_t get(int index) const {
317     return index == N ? dimension_holder_type::operator()()
318                       : rest_type::get(index);
319   }
320 
321   bool set(Axis axis, int32_t t) {
322     if (axis == A) {
323       dimension_holder_type::operator()(t);
324       return true;
325     }
326     return rest_type::set(axis, t);
327   }
328 
329   bool set(int index, int32_t t) {
330     if (index == N) {
331       dimension_holder_type::operator()(t);
332       return true;
333     }
334     return rest_type::set(index, t);
335   }
336 
337   template <Axis B>
338   bool set(int32_t t) {
339     if (A == B) {
340       dimension_holder_type::operator()(t);
341       return true;
342     }
343     return rest_type::template set<B>(t);
344   }
345 };
346 
347 template <Layout T>
348 struct LayoutTraits;
349 
350 #define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...)              \
351   template <>                                                  \
352   struct LayoutTraits<Layout::LayoutName> {                    \
353     using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \
354   }
355 
356 TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH);
357 TFLITE_GPU_LAYOUT_TRAITS(HWD, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH);
358 TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
359                          Axis::INPUT_CHANNELS);
360 TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS,
361                          Axis::HEIGHT, Axis::WIDTH);
362 TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS,
363                          Axis::HEIGHT, Axis::WIDTH);
364 TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
365                          Axis::OUTPUT_CHANNELS);
366 TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH);
367 TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS);
368 TFLITE_GPU_LAYOUT_TRAITS(HWDC, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH,
369                          Axis::CHANNELS);
370 TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE);
371 TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE);
372 TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
373                          Axis::CHANNELS);
374 TFLITE_GPU_LAYOUT_TRAITS(BHWDC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
375                          Axis::DEPTH, Axis::CHANNELS);
376 TFLITE_GPU_LAYOUT_TRAITS(OHWDI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT,
377                          Axis::WIDTH, Axis::DEPTH, Axis::INPUT_CHANNELS);
378 TFLITE_GPU_LAYOUT_TRAITS(HWIO, Axis::HEIGHT, Axis::WIDTH, Axis::INPUT_CHANNELS,
379                          Axis::OUTPUT_CHANNELS);
380 
381 #undef TFLITE_GPU_LAYOUT_TRAITS
382 
383 template <>
384 struct LayoutTraits<Layout::UNKNOWN> {
385   using strong_shape_type = StrongShapeImpl<0>;
386 };
387 
388 template <Axis A>
389 struct DimensionGetterFixedAxisFunc {
390   template <Layout T>
391   int32_t operator()() const {
392     constexpr int i = GetAxisIndex<T>(A);
393     return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
394   }
395   const Shape* l;
396 };
397 
398 struct DimensionGetterFunc {
399   template <Layout T>
400   int32_t operator()() const {
401     int i = GetAxisIndex<T>(axis);
402     return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
403   }
404   Axis axis;
405   const Shape* l;
406 };
407 
408 template <Axis A>
409 struct DimensionSetterFixedAxisFunc {
410   template <Layout T>
411   bool operator()() const {
412     constexpr int i = GetAxisIndex<T>(A);
413     if (i >= 0 && i < l->dimensions.size()) {
414       l->dimensions[i] = v;
415       return true;
416     }
417     return false;
418   }
419   Shape* l;
420   int32_t v;
421 };
422 
423 struct DimensionSetterFunc {
424   template <Layout T>
425   bool operator()() const {
426     int i = GetAxisIndex<T>(axis);
427     if (i >= 0 && i < l->dimensions.size()) {
428       l->dimensions[i] = v;
429       return true;
430     }
431     return false;
432   }
433   Axis axis;
434   Shape* l;
435   int32_t v;
436 };
437 
438 template <Layout L>
439 struct ToShapeFunc {
440   template <Layout T>
441   bool operator()() const {
442     for (int i = 0; i < StrongShape<L>::size(); ++i) {
443       int index = GetAxisIndex<T>(StrongShape<L>::axis(i));
444       if (index < 0) return false;
445       shape->set(i, l.dimensions[index]);
446     }
447     return true;
448   }
449 
450   StrongShape<L>* shape;
451   const Shape& l;
452 };
453 
454 }  // namespace internal_shape
455 
456 // template <Axis... As>
457 template <Layout L>
458 struct StrongShape : public internal_shape::LayoutTraits<L>::strong_shape_type {
459   using strong_shape_type =
460       typename internal_shape::LayoutTraits<L>::strong_shape_type;
461   StrongShape() = default;
462 
463   template <typename... Ts>
464   explicit StrongShape(Ts... t) : strong_shape_type(t...) {}
465 
466   constexpr static Layout layout = L;
467 
468   bool operator==(const StrongShape<L>& shape) const {
469     // TODO(akulik): implement better alternative.
470     return this->ToShape() == shape.ToShape();
471   }
472 
473   bool operator!=(const StrongShape<L>& shape) const {
474     // TODO(akulik): implement better alternative.
475     return this->ToShape() != shape.ToShape();
476   }
477   bool empty() const { return DimensionsProduct() == 0; }
478 
479   // Turns StrongShape into generic shape.
480   Shape ToShape() const {
481     std::vector<int32_t> dimensions(StrongShape::size());
482     for (int i = 0; i < StrongShape::size(); ++i) {
483       dimensions[i] = StrongShape::get(i);
484     }
485     return Shape(L, std::move(dimensions));
486   }
487 
488   // @return all dimensions multiplied
489   int64_t DimensionsProduct() const {
490     int64_t product = 1;
491     for (int i = 0; i < StrongShape::size(); ++i) {
492       product *= StrongShape::get(i);
493     }
494     return product;
495   }
496 
497   // Translates given coordinates of the layout into a linear index assuming
498   // dimensions are sorted in tensor access order e.g. if you access
499   // foobar[i][j][k] order of coordinates should be i,j,k.
500   int64_t LinearIndex(
501       const std::array<int32_t, StrongShape::size()>& coordinates) const {
502     int64_t index = coordinates[0];
503     for (int i = 1; i < StrongShape::size(); ++i) {
504       index = index * StrongShape::get(i) + coordinates[i];
505     }
506     return index;
507   }
508 
509   // Copies all dimensions from the given generic shape into specific shape.
510   // It requires shape to have all axis defined in the given
511   // StrongShape. For example:
512   //   - If this shape is OHWI but given shape is OIHW, Adopt will copy all
513   //     dimensions and return true.
514   //   - If this shape is OIHW but input shape is HW, Adopt will copy H and W
515   //     dimensions and return true, but if this shape is HW and given shape
516   //     OIHW, then Adopt will return false because not all axis are present in
517   //     the input shape.
518   //
519   // @return false if generic shape is not compatible.
520   bool Adopt(const Shape& shape) {
521     return DispatchByLayout(shape.layout,
522                             internal_shape::ToShapeFunc<L>{this, shape});
523   }
524 
525   // For all axis defined in a given shape copies values to this shape.
526   // Therefore, it is possible to copy dimensions from CHW to BCHW, but not
527   // the other way around.
528   //
529   // BCHW bchw;
530   // CHW chw;
531   // bchw.CopyAllGivenAxis(chw);  --> true
532   // chw.CopyAllGivenAxis(bchw);  --> false
533   //
534   // @return false if axis in source shape is not defined here, thus value
535   //         was not copied.
536   template <Layout B>
537   bool CopyAllGivenAxis(const StrongShape<B>& source) {
538     for (int i = 0; i < source.size(); ++i) {
539       if (!StrongShape::set(source.axis(i), source.get(i))) {
540         return false;
541       }
542     }
543     return true;
544   }
545 
546   // For all axis defined in this shape copies values from the given shape.
547   //
548   // BCHW bchw;
549   // CHW chw;
550   // bchw.CopyAllDefinedAxis(chw);  --> false
551   // chw.CopyAllDefinedAxis(bchw);  --> true
552   //
553   // @return false if given shape does not have axis defined here,
554   //         therefore a value was not copied.
555   template <Layout B>
556   bool CopyAllDefinedAxis(const StrongShape<B>& source) {
557     for (int i = 0; i < StrongShape::size(); ++i) {
558       int source_index = source.index(StrongShape::axis(i));
559       if (source_index < 0) {
560         return false;
561       }
562       StrongShape::set(i, source.get(source_index));  // always true
563     }
564     return true;
565   }
566 
567   // Copies values only for matching axis.
568   template <Layout B>
569   void CopyMatchingAxis(const StrongShape<B>& source) {
570     for (int i = 0; i < StrongShape::size(); ++i) {
571       StrongShape::set(source.axis(i), source.get(i));
572     }
573   }
574 
575   // AbslHash function for using in flat hash containers.
576   template <typename H>
577   friend H AbslHashValue(H hash_state, const StrongShape& strong_shape) {
578     for (size_t i = 0; i < strong_shape.size(); ++i) {
579       hash_state = H::combine(std::move(hash_state), strong_shape.get(i));
580     }
581     return hash_state;
582   }
583 };
584 
585 template <Layout T>
586 inline std::string ToString(const StrongShape<T>& s) {
587   return ToString(s.ToShape());
588 }
589 
590 template <Layout L>
591 constexpr Layout StrongShape<L>::layout;
592 
593 template <class F>
594 auto DispatchByLayout(Layout type, F f)
595     -> decltype(f.template operator()<Layout::UNKNOWN>()) {
596   switch (type) {
597     case Layout::HW:
598       return f.template operator()<Layout::HW>();
599     case Layout::HWD:
600       return f.template operator()<Layout::HWD>();
601     case Layout::HWC:
602       return f.template operator()<Layout::HWC>();
603     case Layout::HWDC:
604       return f.template operator()<Layout::HWDC>();
605     case Layout::CHW:
606       return f.template operator()<Layout::CHW>();
607     case Layout::OIHW:
608       return f.template operator()<Layout::OIHW>();
609     case Layout::IOHW:
610       return f.template operator()<Layout::IOHW>();
611     case Layout::OHWI:
612       return f.template operator()<Layout::OHWI>();
613     case Layout::IHWO:
614       return f.template operator()<Layout::IHWO>();
615     case Layout::LINEAR:
616       return f.template operator()<Layout::LINEAR>();
617     case Layout::SCALAR:
618       return f.template operator()<Layout::SCALAR>();
619     case Layout::BHWC:
620       return f.template operator()<Layout::BHWC>();
621     case Layout::BHWDC:
622       return f.template operator()<Layout::BHWDC>();
623     case Layout::OHWDI:
624       return f.template operator()<Layout::OHWDI>();
625     case Layout::HWIO:
626       return f.template operator()<Layout::HWIO>();
627     case Layout::UNKNOWN:
628       return f.template operator()<Layout::UNKNOWN>();
629   }
630 }
631 
632 template <Layout T>
633 constexpr int Size() {
634   return StrongShape<T>::size();
635 }
636 
637 template <Layout T>
638 constexpr Axis GetAxis(int index) {
639   return StrongShape<T>::axis(index);
640 }
641 
642 template <Layout T>
643 constexpr int GetAxisIndex(Axis axis) {
644   return StrongShape<T>::index(axis);
645 }
646 
647 template <Layout T>
648 constexpr bool HasAxis(Axis axis) {
649   return StrongShape<T>::has(axis);
650 }
651 
652 template <Axis D>
653 inline int32_t Shape::get() const {
654   return DispatchByLayout(
655       layout, internal_shape::DimensionGetterFixedAxisFunc<D>{this});
656 }
657 
658 inline int32_t Shape::get(Axis axis) const {
659   return DispatchByLayout(layout,
660                           internal_shape::DimensionGetterFunc{axis, this});
661 }
662 
663 template <Axis D>
664 inline bool Shape::set(int32_t t) {
665   return DispatchByLayout(
666       layout, internal_shape::DimensionSetterFixedAxisFunc<D>{this, t});
667 }
668 
669 inline bool Shape::set(Axis axis, int32_t t) {
670   return DispatchByLayout(layout,
671                           internal_shape::DimensionSetterFunc{axis, this, t});
672 }
673 
674 template <Layout T>
675 std::ostream& operator<<(std::ostream& ostream, const StrongShape<T>& shape) {
676   ostream << ToString(shape);
677   return ostream;
678 }
679 
680 }  // namespace gpu
681 }  // namespace tflite
682 
683 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
684