/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"

namespace Eigen {

/** SpatialConvolutionBackwardInput
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Computes the backprop for the input of a 2D convolution.
 *
 * The output_backward parameter is expected to be a tensor with a rank of 3 or
 * more (depth, height, width, and optionally others).
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width).
 * The output_backward and the kernel must both be in col-major layout. The
 * result will also be in col-major layout.
 *
 * If row_in_stride, col_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every row_in_stride, col_in_stride input
 * pixels.
 *
 * The result can be assigned to a tensor of rank equal to the rank of the
 * output_backward. The dimensions of the result will be channels, height,
 * width (and others if applicable).
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
 *
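 * A minimal usage sketch (the shapes, variable names, and the assumption of a
 * 3x3 kernel with stride-1 SAME forward padding are illustrative only, not
 * part of this header):
 * \code
 *   // Col-major: kernel is (filters, channels, kernel_rows, kernel_cols),
 *   // output_backward is (filters, out_rows, out_cols, batch) and the
 *   // result is (channels, in_rows, in_cols, batch).
 *   Eigen::Tensor<float, 4> kernel(8, 3, 3, 3);
 *   Eigen::Tensor<float, 4> out_grad(8, 32, 32, 16);
 *   Eigen::Tensor<float, 4> in_grad(3, 32, 32, 16);
 *   in_grad = Eigen::SpatialConvolutionBackwardInput(
 *       kernel, out_grad, 32, 32);  // inputRows = inputCols = 32
 * \endcode
 *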
 */
typedef IndexList<type2index<0>, type2index<0>, type2index<1>, type2index<1>>
    ReverseColMajor;
typedef IndexList<type2index<1>, type2index<1>, type2index<0>, type2index<0>>
    ReverseRowMajor;

template <typename OutputBackward, typename Kernel>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<OutputBackward>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseColMajor, const Kernel>>>>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>>>,
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseRowMajor, const Kernel>>>>>>>>
SpatialConvolutionBackwardInput(
    const Kernel& kernel, const OutputBackward& output_backward,
    typename internal::traits<OutputBackward>::Index inputRows,
    typename internal::traits<OutputBackward>::Index inputCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<OutputBackward>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex>>
      kern(kernel);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  static const bool isColMajor =
      (internal::traits<OutputBackward>::Layout == ColMajor);

  static const int NumDims = internal::traits<OutputBackward>::NumDimensions;

  // Number of filters in the kernel. This is the same as the depth of the
  // output_backward tensor.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth, which is also the
  // depth of the computed input backprop.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  // This is the effective kernel size, taking into account the
  // (*_in_stride - 1) zero-values inserted between consecutive kernel elements
  // in atrous convolution.
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);
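  // For example, a 3x3 kernel with row_in_stride == col_in_stride == 2 has an
  // effective size of 3 + (3 - 1) * (2 - 1) = 5, i.e. 5x5.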

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Computing the forward padding
  const TensorIndex forward_pad_top = numext::maxi<Index>(
      0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
  const TensorIndex forward_pad_left = numext::maxi<Index>(
      0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
  const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
  const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;

  const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
                                     2 - padding_top + kernelRowsEff;
  const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
                                    2 - padding_left + kernelColsEff;
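  // Worked example (hypothetical values): inputRows = 5, kernelRowsEff = 3,
  // row_stride = 2 and outputRows = 3 give forward_pad_top = 1,
  // padding_top = 3 - 1 - 1 = 1 and padding_bottom = 5 - 4 - 2 - 1 + 3 = 1.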

  eigen_assert(padding_top >= 0);
  eigen_assert(padding_left >= 0);
  eigen_assert(padding_bottom >= 0);
  eigen_assert(padding_right >= 0);

  // The kernel has dimensions filters X channels X patch_rows X patch_cols.
  // We need to reverse the kernel along dimensions corresponding to rows and
  // cols.
  // TODO(yangke): we can make things slightly faster by collapsing the
  // dimensions where we don't reverse. Try that once we have a faster compiler.
  typedef std::conditional_t<isColMajor, ReverseColMajor, ReverseRowMajor>
      Reverse;
  Reverse kernel_reverse;
  // Reorder the dimensions to:
  //   filters x patch_rows x patch_cols x channels
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    //  From: filters x channels x rows x cols
    //  To:   filters x rows x cols x channels
    kernel_shuffle[0] = 0;
    kernel_shuffle[1] = 2;
    kernel_shuffle[2] = 3;
    kernel_shuffle[3] = 1;
  } else {
    //  From: cols x rows x channels x filters
    //  To:   channels x cols x rows x filters
    kernel_shuffle[0] = 2;
    kernel_shuffle[1] = 0;
    kernel_shuffle[2] = 1;
    kernel_shuffle[3] = 3;
  }

  // Collapse the dims
  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[1] = kernelChannels;
  } else {
    kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[0] = kernelChannels;
  }

  // The output_backward has dimensions out_depth X out_rows X out_cols X
  // OTHERS. When we extract the image patches from output_backward, it will
  // have dimensions:
  //   out_depth X (patch_rows * patch_cols) X (input_rows * input_cols * OTHERS)
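  // For instance, with the hypothetical shapes from the usage sketch above
  // (8 filters, 3x3 kernel, 32x32 input, batch of 16), the col-major
  // pre_contract_dims below would be [8 * 3 * 3, 32 * 32 * 16] = [72, 16384].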
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[1] = inputRows * inputCols;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= out.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[0] = inputRows * inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= out.dimension(i);
    }
  }

  // We will contract along the collapsed dimension that contains the
  // kernelFilters, the kernelRows and the kernelCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  if (isColMajor) {
    // col-major: kernel.contract(output.patches)
    contract_dims[0] = IndexPair<TensorIndex>(0, 0);
  } else {
    // row-major: output.patches.contract(kernel)
    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
  }

  // Post contraction, the dimensions of the input_backprop are
  //  channels X input_rows X input_cols X OTHERS
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = inputRows;
    post_contract_dims[2] = inputCols;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelChannels;
    post_contract_dims[NumDims - 2] = inputRows;
    post_contract_dims[NumDims - 3] = inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  }

  // NOTE(ezhulenev): We do eval after reverse and shuffle, because tiled
  // evaluation of these ops does not compose. Doing explicit eval is ~8x
  // faster in micro benchmarks.

  return choose(
      Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
      kernel.reverse(kernel_reverse)
          .eval()
          .shuffle(kernel_shuffle)
          .eval()
          .reshape(kernel_dims)
          .contract(
              output_backward
                  .extract_image_patches(
                      kernelRows, kernelCols, 1, 1, row_in_stride,
                      col_in_stride, row_stride, col_stride, padding_top,
                      padding_bottom, padding_left, padding_right, OutScalar(0))
                  .reshape(pre_contract_dims),
              contract_dims)
          .reshape(post_contract_dims),
      output_backward
          .extract_image_patches(kernelRows, kernelCols, 1, 1, row_in_stride,
                                 col_in_stride, row_stride, col_stride,
                                 padding_top, padding_bottom, padding_left,
                                 padding_right, OutScalar(0))
          .reshape(pre_contract_dims)
          .contract(kernel.reverse(kernel_reverse)
                        .eval()
                        .shuffle(kernel_shuffle)
                        .eval()
                        .reshape(kernel_dims),
                    contract_dims)
          .reshape(post_contract_dims));
}

/** SpatialConvolutionBackwardKernel
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Computes the backprop for the filter of a 2D convolution.
 *
 * The input parameter is expected to be a 4D tensor (channels, height, width,
 * batch).
 * The output_backward parameter is expected to be a 4D tensor (filters,
 * height, width, batch).
 * The input and the output_backward must both be in col-major layout. The
 * result will also be in col-major layout.
 *
 * If row_in_stride, col_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every row_in_stride, col_in_stride input
 * pixels.
 *
 * The result is a 4D tensor of dimensions (filters, channels, kernel_height,
 * kernel_width).
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
 *
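 * A minimal usage sketch (the shapes, variable names, and the assumption of a
 * 3x3 kernel with stride-1 SAME forward padding are illustrative only, not
 * part of this header):
 * \code
 *   // Col-major: input is (channels, in_rows, in_cols, batch),
 *   // output_backward is (filters, out_rows, out_cols, batch) and the
 *   // result is (filters, channels, kernel_rows, kernel_cols).
 *   Eigen::Tensor<float, 4> input(3, 32, 32, 16);
 *   Eigen::Tensor<float, 4> out_grad(8, 32, 32, 16);
 *   Eigen::Tensor<float, 4> kernel_grad(8, 3, 3, 3);
 *   kernel_grad = Eigen::SpatialConvolutionBackwardKernel(
 *       input, out_grad, 3, 3);  // kernelRows = kernelCols = 3
 * \endcode
 *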
 */

template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<Input>::Layout == ColMajor,
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>>>>>>,
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>>>>>>>
SpatialConvolutionBackwardKernel(
    const Input& input, const OutputBackward& output_backward,
    typename internal::traits<Input>::Index kernelRows,
    typename internal::traits<Input>::Index kernelCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex>>
      in(input);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  // stride and in_stride cannot both be larger than 1
  eigen_assert(!(row_stride > 1 && row_in_stride > 1));
  eigen_assert(!(col_stride > 1 && col_in_stride > 1));

  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  static const int NumDims = internal::traits<Input>::NumDimensions;
  EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions ==
                          internal::traits<OutputBackward>::NumDimensions,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(NumDims == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);

  const TensorIndex inputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Number of filters in the kernel being computed. This is the same as the
  // depth of the output_backward tensor.
  const TensorIndex kernelFilters =
      isColMajor ? out.dimensions()[0] : out.dimensions()[NumDims - 1];

  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? in.dimensions()[0] : in.dimensions()[NumDims - 1];

  // This is the effective kernel size, taking into account the
  // (*_in_stride - 1) zero-values inserted between consecutive kernel
  // elements in atrous convolution
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  // Number of batches (and other dimensions) in the input tensor.
  TensorIndex batch = 1;
  for (int d = 3; d < NumDims; ++d) {
    batch *= isColMajor ? in.dimension(d) : in.dimension(NumDims - d - 1);
  }

  // Computing the forward padding
  const TensorIndex padRows = numext::maxi<Index>(
      0, (outputRows - 1) * row_stride + kernelRowsEff - inputRows);
  const TensorIndex padCols = numext::maxi<Index>(
      0, (outputCols - 1) * col_stride + kernelColsEff - inputCols);

  TensorIndex padding_top = padRows / 2;
  TensorIndex padding_left = padCols / 2;

  // Compute paddings for output_backward before extracting patches.
  const TensorIndex expanded_out_rows = (outputRows - 1) * row_stride + 1;
  const TensorIndex expanded_out_cols = (outputCols - 1) * col_stride + 1;

  const TensorIndex padded_out_rows = inputRows + kernelRowsEff - 1;
  const TensorIndex padded_out_cols = inputCols + kernelColsEff - 1;

  const TensorIndex top_pad_rows = kernelRowsEff - 1 - padding_top;
  const TensorIndex left_pad_cols = kernelColsEff - 1 - padding_left;

  const TensorIndex bottom_pad_rows =
      padded_out_rows - expanded_out_rows - top_pad_rows;
  const TensorIndex right_pad_cols =
      padded_out_cols - expanded_out_cols - left_pad_cols;
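  // Worked example (hypothetical values): inputRows = 32, outputRows = 32,
  // kernelRowsEff = 3 and row_stride = 1 give padRows = 2, padding_top = 1,
  // expanded_out_rows = 32, padded_out_rows = 34, top_pad_rows = 1 and
  // bottom_pad_rows = 34 - 32 - 1 = 1.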

  // Reorder output_backward dimensions.
  array<TensorIndex, 4> output_backward_shuffle;
  if (isColMajor) {
    // From: [out_depth, out_rows, out_cols, batch]
    // To:   [batch, out_rows, out_cols, out_depth]
    output_backward_shuffle = {3, 1, 2, 0};
  } else {
    // From: [batch, out_cols, out_rows, out_depth]
    // To:   [out_depth, out_cols, out_rows, batch]
    output_backward_shuffle = {3, 1, 2, 0};
  }

  // Reorder input dimensions.
  array<TensorIndex, 4> input_shuffle;
  if (isColMajor) {
    // From: [in_depth, in_rows, in_cols, batch]
    // To:   [in_depth, batch, in_rows, in_cols]
    input_shuffle = {0, 3, 1, 2};
  } else {
    // From: [batch, in_cols, in_rows, in_depth]
    // To:   [in_cols, in_rows, batch, in_depth]
    input_shuffle = {1, 2, 0, 3};
  }

  // Input is playing the role of a "kernel" in this convolution.
  DSizes<TensorIndex, 2> input_dims;
  if (isColMajor) {
    input_dims[0] = kernelChannels;
    input_dims[1] = batch * inputRows * inputCols;
  } else {
    input_dims[1] = kernelChannels;
    input_dims[0] = inputCols * inputRows * batch;
  }

  // Molds the output of the patch extraction result into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
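  // For instance, with the hypothetical shapes from the usage sketch above
  // (3 input channels, 3x3 kernel, 32x32 input, 8 filters, batch of 16), the
  // col-major pre_contract_dims below would be
  // [16 * 32 * 32, 3 * 3 * 8] = [16384, 72].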
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = batch * inputRows * inputCols;
    pre_contract_dims[1] = kernelRows * kernelCols * kernelFilters;
  } else {
    pre_contract_dims[1] = inputCols * inputRows * batch;
    pre_contract_dims[0] = kernelFilters * kernelCols * kernelRows;
  }

  // We will contract along the collapsed dimension that contains the
  // batch, inputRows and inputCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Dimensions after contraction.
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = kernelRows;
    post_contract_dims[2] = kernelCols;
    post_contract_dims[3] = kernelFilters;
  } else {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = kernelCols;
    post_contract_dims[2] = kernelRows;
    post_contract_dims[3] = kernelChannels;
  }

  // Reorder output of contraction to a valid filter shape.
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    // From: [in_depth, kernel_rows, kernel_cols, out_depth]
    // To:   [out_depth, in_depth, kernel_rows, kernel_cols]
    kernel_shuffle = {3, 0, 1, 2};
  } else {
    // From: [out_depth, kernel_cols, kernel_rows, in_depth]
    // To:   [kernel_cols, kernel_rows, in_depth, out_depth]
    kernel_shuffle = {1, 2, 3, 0};
  }

  // Reverse kernel backprop dimensions.
  array<TensorIndex, 4> kernel_reverse;
  if (isColMajor) {
    kernel_reverse = {false, false, true, true};
  } else {
    kernel_reverse = {true, true, false, false};
  }

  // Create convolution input (aka source of patches) from output backward
  // tensor by shuffling dimensions.
  const auto output_backward_shuffled =
      output_backward.shuffle(output_backward_shuffle).eval();

  // Create convolution kernel (aka filter) from input by shuffling and
  // reshaping.
  const auto input_shuffled =
      input.shuffle(input_shuffle).eval().reshape(input_dims);

  return choose(
             Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
             input_shuffled.contract(
                 output_backward_shuffled
                     .extract_image_patches(inputRows, inputCols, row_in_stride,
                                            col_in_stride, 1, 1, row_stride,
                                            col_stride, top_pad_rows,
                                            bottom_pad_rows, left_pad_cols,
                                            right_pad_cols, OutScalar(0))
                     .reshape(pre_contract_dims),
                 contract_dims),
             output_backward_shuffled
                 .extract_image_patches(
                     inputRows, inputCols, row_in_stride, col_in_stride, 1, 1,
                     row_stride, col_stride, top_pad_rows, bottom_pad_rows,
                     left_pad_cols, right_pad_cols, OutScalar(0))
                 .reshape(pre_contract_dims)
                 .contract(input_shuffled, contract_dims))
      .reshape(post_contract_dims)
      .shuffle(kernel_shuffle)
      .eval()
      .reverse(kernel_reverse);
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_