xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
17 
18 #include <algorithm>
19 
20 #include "ruy/profiler/instrumentation.h"  // from @ruy
21 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
22 #include "tensorflow/lite/kernels/internal/types.h"
23 
24 namespace tflite {
25 namespace optimized_ops {
26 
27 // Implementation of float DepthwiseConv
28 
// Primary template for the float depthwise-conv row kernels. It is
// intentionally empty: only the (kAllowStrided, kFixedInputDepth,
// kFixedDepthMultiplier) combinations that have an explicit specialization in
// the USE_NEON section below provide `static void Run(...)`, so naming Run on
// any other combination fails to compile.
//   kAllowStrided:         the `false` specializations assume consecutive
//                          output pixels read consecutive input pixels and do
//                          not read input_ptr_increment; the `true` ones
//                          advance the input by input_ptr_increment per pixel.
//   kFixedInputDepth:      0 means the kernel loops over the runtime
//                          input_depth; otherwise the depth is baked in.
//   kFixedDepthMultiplier: the depth multiplier the specialization handles.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct FloatDepthwiseConvKernel {};
31 
32 #ifdef USE_NEON
33 
34 template <>
35 struct FloatDepthwiseConvKernel<false, 8, 1> {
36   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
37                   const float* input_ptr, int input_ptr_increment,
38                   const float* filter_ptr, float* acc_buffer_ptr) {
39     // Load the filters
40     float32x4_t filter[2];
41     for (int i = 0; i < 2; i++) {
42       filter[i] = vld1q_f32(filter_ptr + 4 * i);
43     }
44     int outp = 0;
45     // Handle 2 output pixels at a time.
46     for (; outp <= num_output_pixels - 2; outp += 2) {
47       // Load the inputs
48       float32x4_t input[4];
49       for (int i = 0; i < 4; i++) {
50         input[i] = vld1q_f32(input_ptr + 4 * i);
51       }
52       input_ptr += 16;
53       // Load the accumulators from acc_buffer
54       float32x4_t acc[4];
55       for (int i = 0; i < 4; i++) {
56         acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
57       }
58       // Multiply-accumulate
59       acc[0] = vmlaq_f32(acc[0], input[0], filter[0]);
60       acc[1] = vmlaq_f32(acc[1], input[1], filter[1]);
61       acc[2] = vmlaq_f32(acc[2], input[2], filter[0]);
62       acc[3] = vmlaq_f32(acc[3], input[3], filter[1]);
63       // Store the accumulators back to acc_buffer
64       for (int i = 0; i < 4; i++) {
65         vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
66       }
67       acc_buffer_ptr += 16;
68     }
69     // Handle one output pixel at a time.
70     for (; outp < num_output_pixels; outp++) {
71       // Load the inputs
72       float32x4_t input[2];
73       for (int i = 0; i < 2; i++) {
74         input[i] = vld1q_f32(input_ptr + 4 * i);
75       }
76       input_ptr += 8;
77       // Load the accumulators from acc_buffer
78       float32x4_t acc[2];
79       for (int i = 0; i < 2; i++) {
80         acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
81       }
82       // Multiply-accumulate
83       for (int i = 0; i < 2; i++) {
84         acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
85       }
86       // Store the accumulators back to acc_buffer
87       for (int i = 0; i < 2; i++) {
88         vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
89       }
90       acc_buffer_ptr += 8;
91     }
92   }
93 };
94 
95 template <>
96 struct FloatDepthwiseConvKernel<false, 2, 1> {
97   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
98                   const float* input_ptr, int input_ptr_increment,
99                   const float* filter_ptr, float* acc_buffer_ptr) {
100     const float32x2_t filters = vld1_f32(filter_ptr);
101     const float32x4_t filters_dup2 = vcombine_f32(filters, filters);
102     int outp = 0;
103     // Handle 8 output pixels at a time.
104     for (; outp <= num_output_pixels - 8; outp += 8) {
105       // Load the inputs
106       float32x4_t input[4];
107       for (int i = 0; i < 4; i++) {
108         input[i] = vld1q_f32(input_ptr + 4 * i);
109       }
110       input_ptr += 16;
111       // Load the accumulators from acc_buffer
112       float32x4_t acc[4];
113       for (int i = 0; i < 4; i++) {
114         acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
115       }
116       // Multiply-accumulate
117       for (int i = 0; i < 4; i++) {
118         acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
119       }
120       // Store the accumulators back to acc_buffer
121       for (int i = 0; i < 4; i++) {
122         vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
123       }
124       acc_buffer_ptr += 16;
125     }
126     // Handle 4 output pixels at a time.
127     for (; outp <= num_output_pixels - 4; outp += 4) {
128       // Load the inputs
129       float32x4_t input[2];
130       for (int i = 0; i < 2; i++) {
131         input[i] = vld1q_f32(input_ptr + 4 * i);
132       }
133       input_ptr += 8;
134       // Load the accumulators from acc_buffer
135       float32x4_t acc[2];
136       for (int i = 0; i < 2; i++) {
137         acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
138       }
139       // Multiply-accumulate
140       for (int i = 0; i < 2; i++) {
141         acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
142       }
143       // Store the accumulators back to acc_buffer
144       for (int i = 0; i < 2; i++) {
145         vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
146       }
147       acc_buffer_ptr += 8;
148     }
149     // Handle 2 output pixels at a time.
150     for (; outp <= num_output_pixels - 2; outp += 2) {
151       // Load the inputs
152       const float32x4_t input = vld1q_f32(input_ptr);
153       input_ptr += 4;
154       // Load the accumulators from acc_buffer
155       float32x4_t acc = vld1q_f32(acc_buffer_ptr);
156       // Multiply-accumulate
157       acc = vmlaq_f32(acc, input, filters_dup2);
158       // Store the accumulators back to acc_buffer
159       vst1q_f32(acc_buffer_ptr, acc);
160       acc_buffer_ptr += 4;
161     }
162     // Handle 1 output pixel at a time
163     for (; outp < num_output_pixels; outp++) {
164       // Load the inputs
165       const float32x2_t input = vld1_f32(input_ptr);
166       input_ptr += 2;
167       // Load the accumulators from acc_buffer
168       float32x2_t acc = vld1_f32(acc_buffer_ptr);
169       // Multiply-accumulate
170       acc = vmla_f32(acc, input, filters);
171       // Store the accumulators back to acc_buffer
172       vst1_f32(acc_buffer_ptr, acc);
173       acc_buffer_ptr += 2;
174     }
175   }
176 };
177 
178 template <>
179 struct FloatDepthwiseConvKernel<true, 0, 1> {
180   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
181                   const float* input_ptr, int input_ptr_increment,
182                   const float* filter_ptr, float* acc_buffer_ptr) {
183     // Handle one output pixel at a time.
184     for (int outp = 0; outp < num_output_pixels; outp++) {
185       const float* local_filter_ptr = filter_ptr;
186       const float* local_input_ptr = input_ptr;
187       int ic = 0;
188       // Handle 16 input channels at a time.
189       for (; ic <= input_depth - 16; ic += 16) {
190         // Load the filters
191         float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0);
192         float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1);
193         float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2);
194         float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3);
195         local_filter_ptr += 16;
196         // Load the inputs
197         float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0);
198         float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1);
199         float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2);
200         float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3);
201         local_input_ptr += 16;
202         // Load the accumulators from acc_buffer
203         float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
204         float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
205         float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
206         float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
207         // Multiply-accumulate
208         acc_0 = vmlaq_f32(acc_0, input_0, filter_0);
209         acc_1 = vmlaq_f32(acc_1, input_1, filter_1);
210         acc_2 = vmlaq_f32(acc_2, input_2, filter_2);
211         acc_3 = vmlaq_f32(acc_3, input_3, filter_3);
212         // Store the accumulators back to acc_buffer
213         vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
214         vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
215         vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
216         vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
217         acc_buffer_ptr += 16;
218       }
219       // Handle 4 input channels at a time.
220       for (; ic <= input_depth - 4; ic += 4) {
221         // Load the filters
222         float32x4_t filter;
223         filter = vld1q_f32(local_filter_ptr);
224         local_filter_ptr += 4;
225         // Load the inputs
226         float32x4_t input;
227         input = vld1q_f32(local_input_ptr);
228         local_input_ptr += 4;
229         // Load the accumulators from acc_buffer
230         float32x4_t acc;
231         acc = vld1q_f32(acc_buffer_ptr);
232         // Multiply-accumulate
233         acc = vmlaq_f32(acc, input, filter);
234         // Store the accumulators back to acc_buffer
235         vst1q_f32(acc_buffer_ptr, acc);
236         acc_buffer_ptr += 4;
237       }
238       // Handle one input channel at a time.
239       for (; ic < input_depth; ic++) {
240         const float input_val = *local_input_ptr++;
241         const float filter_val = *local_filter_ptr++;
242         *acc_buffer_ptr++ += filter_val * input_val;
243       }
244       input_ptr += input_ptr_increment;
245     }
246   }
247 };
248 
249 template <>
250 struct FloatDepthwiseConvKernel<true, 0, 8> {
251   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
252                   const float* input_ptr, int input_ptr_increment,
253                   const float* filter_ptr, float* acc_buffer_ptr) {
254     // Handle one output pixel at a time.
255     for (int outp = 0; outp < num_output_pixels; outp++) {
256       const float* local_filter_ptr = filter_ptr;
257       const float* local_input_ptr = input_ptr;
258       int ic = 0;
259       // Handle 2 input channels at a time.
260       for (; ic <= input_depth - 2; ic += 2) {
261         // Load the filters
262         float32x4_t filter[4];
263         for (int i = 0; i < 4; i++) {
264           filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
265         }
266         local_filter_ptr += 16;
267         // Load the inputs
268         const float32x2_t input = vld1_f32(local_input_ptr);
269         local_input_ptr += 2;
270         // Load the accumulators from acc_buffer
271         float32x4_t acc[4];
272         for (int i = 0; i < 4; i++) {
273           acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
274         }
275         // Multiply-accumulate
276         acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0);
277         acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0);
278         acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1);
279         acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1);
280         // Store the accumulators back to acc_buffer
281         for (int i = 0; i < 4; i++) {
282           vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
283         }
284         acc_buffer_ptr += 16;
285       }
286       // Handle one input channel at a time.
287       for (; ic < input_depth; ic++) {
288         // Load the filters
289         float32x4_t filter[2];
290         for (int i = 0; i < 2; i++) {
291           filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
292         }
293         local_filter_ptr += 8;
294         // Load the inputs
295         const float input_val = *local_input_ptr++;
296         // Load the accumulators from acc_buffer
297         float32x4_t acc[2];
298         for (int i = 0; i < 2; i++) {
299           acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
300         }
301         // Multiply-accumulate
302         for (int i = 0; i < 2; i++) {
303           acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
304         }
305         // Store the accumulators back to acc_buffer
306         for (int i = 0; i < 2; i++) {
307           vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
308         }
309         acc_buffer_ptr += 8;
310       }
311       input_ptr += input_ptr_increment;
312     }
313   }
314 };
315 
// Note: this implementation is very slow for input_depth < 8 (roughly
// comparable to the reference implementation); see the specializations for
// input_depth == 3 below.
319 template <>
320 struct FloatDepthwiseConvKernel<true, 0, 2> {
321   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
322                   const float* input_ptr, int input_ptr_increment,
323                   const float* filter_ptr, float* acc_buffer_ptr) {
324     // Handle one output pixel at a time.
325     for (int outp = 0; outp < num_output_pixels; outp++) {
326       const float* local_filter_ptr = filter_ptr;
327       const float* local_input_ptr = input_ptr;
328       int ic = 0;
329       // Handle 8 input channels at a time.
330       for (; ic <= input_depth - 8; ic += 8) {
331         // Load the filters
332         float32x4_t filter[4];
333         for (int i = 0; i < 4; i++) {
334           filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
335         }
336         local_filter_ptr += 16;
337         // Load the inputs
338         float32x4x2_t input_dup2[2];
339         for (int i = 0; i < 2; i++) {
340           const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i);
341           input_dup2[i] = vzipq_f32(input, input);
342         }
343         local_input_ptr += 8;
344         // Load the accumulators from acc_buffer
345         float32x4_t acc[4];
346         for (int i = 0; i < 4; i++) {
347           acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
348         }
349         // Multiply-accumulate
350         acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]);
351         acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]);
352         acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]);
353         acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]);
354         // Store the accumulators back to acc_buffer
355         for (int i = 0; i < 4; i++) {
356           vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
357         }
358         acc_buffer_ptr += 16;
359       }
360       // Handle 4 input channels at a time.
361       for (; ic <= input_depth - 4; ic += 4) {
362         // Load the filters
363         float32x2_t filter[4];
364         for (int i = 0; i < 4; i++) {
365           filter[i] = vld1_f32(local_filter_ptr + 2 * i);
366         }
367         local_filter_ptr += 8;
368         // Load the inputs
369         const float32x4_t input = vld1q_f32(local_input_ptr);
370         local_input_ptr += 4;
371         // Load the accumulators from acc_buffer
372         float32x2_t acc[4];
373         for (int i = 0; i < 4; i++) {
374           acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
375         }
376         // Multiply-accumulate
377         acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0);
378         acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1);
379         acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0);
380         acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1);
381         // Store the accumulators back to acc_buffer
382         for (int i = 0; i < 4; i++) {
383           vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
384         }
385         acc_buffer_ptr += 8;
386       }
387       // Handle 2 input channels at a time.
388       for (; ic <= input_depth - 2; ic += 2) {
389         // Load the filters
390         const float32x4_t filter = vld1q_f32(local_filter_ptr);
391         local_filter_ptr += 4;
392         // Load the inputs
393         const float32x2_t input = vld1_f32(local_input_ptr);
394         local_input_ptr += 2;
395         // Load the accumulators from acc_buffer
396         float32x2_t acc[2];
397         for (int i = 0; i < 2; i++) {
398           acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
399         }
400         // Multiply-accumulate
401         acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0);
402         acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1);
403         // Store the accumulators back to acc_buffer
404         for (int i = 0; i < 2; i++) {
405           vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
406         }
407         acc_buffer_ptr += 4;
408       }
409       // Handle one input channel at a time.
410       for (; ic < input_depth; ic++) {
411         // Load the inputs
412         const float input_val = *local_input_ptr++;
413         // Multiply-accumulate
414         for (int i = 0; i < 2; i++) {
415           acc_buffer_ptr[i] += local_filter_ptr[i] * input_val;
416         }
417         local_filter_ptr += 2;
418         acc_buffer_ptr += 2;
419       }
420       input_ptr += input_ptr_increment;
421     }
422   }
423 };
424 
425 template <>
426 struct FloatDepthwiseConvKernel<true, 3, 2> {
427   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
428                   const float* input_ptr, int input_ptr_increment,
429                   const float* filter_ptr, float* acc_buffer_ptr) {
430     // Load the filters
431     float32x2_t filter[3];
432     for (int i = 0; i < 3; i++) {
433       filter[i] = vld1_f32(filter_ptr + 2 * i);
434     }
435     // Handle one output pixel at a time.
436     for (int outp = 0; outp < num_output_pixels; outp++) {
437       const float32x2_t input01 = vld1_f32(input_ptr);
438       const float32x2_t input2 = vld1_dup_f32(input_ptr + 2);
439       // Load the accumulators from acc_buffer
440       float32x2_t acc[3];
441       for (int i = 0; i < 3; i++) {
442         acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
443       }
444       // Multiply-accumulate for each input channel there 2 outputs
445       acc[0] = vmla_lane_f32(acc[0], filter[0], input01, 0);
446       acc[1] = vmla_lane_f32(acc[1], filter[1], input01, 1);
447       acc[2] = vmla_lane_f32(acc[2], filter[2], input2, 0);
448       // Store the accumulators back to acc_buffer
449       for (int i = 0; i < 3; i++) {
450         vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
451       }
452       acc_buffer_ptr += 6;
453       input_ptr += input_ptr_increment;
454     }
455   }
456 };
457 
458 template <>
459 struct FloatDepthwiseConvKernel<true, 3, 4> {
460   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
461                   const float* input_ptr, int input_ptr_increment,
462                   const float* filter_ptr, float* acc_buffer_ptr) {
463     // Load the filters
464     float32x4_t filter[3];
465     for (int i = 0; i < 3; i++) {
466       filter[i] = vld1q_f32(filter_ptr + 4 * i);
467     }
468     // Handle one output pixel at a time.
469     for (int outp = 0; outp < num_output_pixels; outp++) {
470       // NOTE: we only want 3 values, so we read it as two ops where
471       // the second op just duplicates the lane
472       const float32x2_t input01 = vld1_f32(input_ptr);
473       const float32x2_t input2 = vld1_dup_f32(input_ptr + 2);
474       // Load the accumulators from acc_buffer
475       float32x4_t acc[3];
476       for (int i = 0; i < 3; i++) {
477         acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
478       }
479       // Multiply-accumulate all outputs.
480       acc[0] = vmlaq_lane_f32(acc[0], filter[0], input01, 0);
481       acc[1] = vmlaq_lane_f32(acc[1], filter[1], input01, 1);
482       acc[2] = vmlaq_lane_f32(acc[2], filter[2], input2, 0);
483       // Store the accumulators back to acc_buffer
484       for (int i = 0; i < 3; i++) {
485         vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
486       }
487       acc_buffer_ptr += 12;
488       input_ptr += input_ptr_increment;
489     }
490   }
491 };
492 
493 template <>
494 struct FloatDepthwiseConvKernel<true, 1, 8> {
495   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
496                   const float* input_ptr, int input_ptr_increment,
497                   const float* filter_ptr, float* acc_buffer_ptr) {
498     // Load the filters
499     float32x4_t filter[2];
500     for (int i = 0; i < 2; i++) {
501       filter[i] = vld1q_f32(filter_ptr + 4 * i);
502     }
503     // Handle one output pixel at a time.
504     for (int outp = 0; outp < num_output_pixels; outp++) {
505       // Load the inputs
506       const float input_val = *input_ptr;
507       input_ptr += input_ptr_increment;
508       // Load the accumulators from acc_buffer
509       float32x4_t acc[2];
510       for (int i = 0; i < 2; i++) {
511         acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
512       }
513       // Multiply-accumulate
514       for (int i = 0; i < 2; i++) {
515         acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
516       }
517       // Store the accumulators back to acc_buffer
518       for (int i = 0; i < 2; i++) {
519         vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
520       }
521       acc_buffer_ptr += 8;
522     }
523   }
524 };
525 
526 template <>
527 struct FloatDepthwiseConvKernel<true, 1, 32> {
528   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
529                   const float* input_ptr, int input_ptr_increment,
530                   const float* filter_ptr, float* acc_buffer_ptr) {
531     // Load the filters
532     float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0);
533     float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1);
534     float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2);
535     float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3);
536     float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4);
537     float32x4_t filter_5 = vld1q_f32(filter_ptr + 4 * 5);
538     float32x4_t filter_6 = vld1q_f32(filter_ptr + 4 * 6);
539     float32x4_t filter_7 = vld1q_f32(filter_ptr + 4 * 7);
540 
541     // Handle one output pixel at a time.
542     for (int outp = 0; outp < num_output_pixels; outp++) {
543       // Load the inputs
544       const float input_val = *input_ptr;
545       input_ptr += input_ptr_increment;
546       // Load the accumulators from acc_buffer
547       float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
548       float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
549       float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
550       float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
551       float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4);
552       float32x4_t acc_5 = vld1q_f32(acc_buffer_ptr + 4 * 5);
553       float32x4_t acc_6 = vld1q_f32(acc_buffer_ptr + 4 * 6);
554       float32x4_t acc_7 = vld1q_f32(acc_buffer_ptr + 4 * 7);
555       // Multiply-accumulate
556       acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val);
557       acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val);
558       acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val);
559       acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val);
560       acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val);
561       acc_5 = vmlaq_n_f32(acc_5, filter_5, input_val);
562       acc_6 = vmlaq_n_f32(acc_6, filter_6, input_val);
563       acc_7 = vmlaq_n_f32(acc_7, filter_7, input_val);
564       // Store the accumulators back to acc_buffer
565       vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
566       vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
567       vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
568       vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
569       vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4);
570       vst1q_f32(acc_buffer_ptr + 4 * 5, acc_5);
571       vst1q_f32(acc_buffer_ptr + 4 * 6, acc_6);
572       vst1q_f32(acc_buffer_ptr + 4 * 7, acc_7);
573       acc_buffer_ptr += 32;
574     }
575   }
576 };
577 
578 template <>
579 struct FloatDepthwiseConvKernel<true, 1, 20> {
580   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
581                   const float* input_ptr, int input_ptr_increment,
582                   const float* filter_ptr, float* acc_buffer_ptr) {
583     // Load the filters
584     float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0);
585     float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1);
586     float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2);
587     float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3);
588     float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4);
589 
590     // Handle one output pixel at a time.
591     for (int outp = 0; outp < num_output_pixels; outp++) {
592       // Load the inputs
593       const float input_val = *input_ptr;
594       input_ptr += input_ptr_increment;
595       // Load the accumulators from acc_buffer
596       float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
597       float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
598       float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
599       float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
600       float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4);
601       // Multiply-accumulate
602       acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val);
603       acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val);
604       acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val);
605       acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val);
606       acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val);
607       // Store the accumulators back to acc_buffer
608       vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
609       vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
610       vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
611       vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
612       vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4);
613       acc_buffer_ptr += 20;
614     }
615   }
616 };
617 
618 template <>
619 struct FloatDepthwiseConvKernel<true, 0, 16> {
620   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
621                   const float* input_ptr, int input_ptr_increment,
622                   const float* filter_ptr, float* acc_buffer_ptr) {
623     // Handle one output pixel at a time.
624     for (int outp = 0; outp < num_output_pixels; outp++) {
625       const float* local_filter_ptr = filter_ptr;
626       const float* local_input_ptr = input_ptr;
627       for (int ic = 0; ic < input_depth; ic++) {
628         // Load the filters
629         float32x4_t filter[4];
630         for (int i = 0; i < 4; i++) {
631           filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
632         }
633         local_filter_ptr += 16;
634         // Load the inputs
635         const float input_val = *local_input_ptr++;
636         // Load the accumulators from acc_buffer
637         float32x4_t acc[4];
638         for (int i = 0; i < 4; i++) {
639           acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
640         }
641         // Multiply-accumulate
642         for (int i = 0; i < 4; i++) {
643           acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
644         }
645         // Store the accumulators back to acc_buffer
646         for (int i = 0; i < 4; i++) {
647           vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
648         }
649         acc_buffer_ptr += 16;
650       }
651       input_ptr += input_ptr_increment;
652     }
653   }
654 };
655 
656 template <>
657 struct FloatDepthwiseConvKernel<true, 8, 1> {
658   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
659                   const float* input_ptr, int input_ptr_increment,
660                   const float* filter_ptr, float* acc_buffer_ptr) {
661     // Load the filters
662     float32x4_t filter[2];
663     for (int i = 0; i < 2; i++) {
664       filter[i] = vld1q_f32(filter_ptr + 4 * i);
665     }
666     // Handle one output pixel at a time.
667     for (int outp = 0; outp < num_output_pixels; outp++) {
668       // Load the inputs
669       float32x4_t input[2];
670       for (int i = 0; i < 2; i++) {
671         input[i] = vld1q_f32(input_ptr + 4 * i);
672       }
673       // Load the accumulators from acc_buffer
674       float32x4_t acc[2];
675       for (int i = 0; i < 2; i++) {
676         acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
677       }
678       // Multiply-accumulate
679       for (int i = 0; i < 2; i++) {
680         acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
681       }
682       // Store the accumulators back to acc_buffer
683       for (int i = 0; i < 2; i++) {
684         vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
685       }
686       acc_buffer_ptr += 8;
687       input_ptr += input_ptr_increment;
688     }
689   }
690 };
691 
692 template <>
693 struct FloatDepthwiseConvKernel<true, 2, 1> {
694   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
695                   const float* input_ptr, int input_ptr_increment,
696                   const float* filter_ptr, float* acc_buffer_ptr) {
697     float32x2_t filter = vld1_f32(filter_ptr);
698     float32x4_t filter_x4 = vcombine_f32(filter, filter);
699     int outp = 0;
700 
701     // Handle two output pixels at a time.
702     for (; outp <= num_output_pixels - 2; outp += 2) {
703       // Load the inputs
704       float32x2_t input_1 = vld1_f32(input_ptr);
705       input_ptr += input_ptr_increment;
706       float32x2_t input_2 = vld1_f32(input_ptr);
707       input_ptr += input_ptr_increment;
708       float32x4_t input = vcombine_f32(input_1, input_2);
709 
710       // Load the accumulators from acc_buffer
711       float32x4_t acc = vld1q_f32(acc_buffer_ptr);
712 
713       // Multiply-accumulate
714       acc = vmlaq_f32(acc, input, filter_x4);
715 
716       // Store the accumulators back to acc_buffer
717       vst1q_f32(acc_buffer_ptr, acc);
718       acc_buffer_ptr += 4;
719     }
720     // Handle one output pixel at a time.
721     for (; outp < num_output_pixels; outp++) {
722       // Load the inputs
723       float32x2_t input = vld1_f32(input_ptr);
724       input_ptr += input_ptr_increment;
725 
726       // Load the accumulators from acc_buffer
727       float32x2_t acc = vld1_f32(acc_buffer_ptr);
728 
729       // Multiply-accumulate
730       acc = vmla_f32(acc, input, filter);
731 
732       // Store the accumulators back to acc_buffer
733       vst1_f32(acc_buffer_ptr, acc);
734       acc_buffer_ptr += 2;
735     }
736   }
737 };
738 
739 template <>
740 struct FloatDepthwiseConvKernel<true, 4, 1> {
741   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
742                   const float* input_ptr, int input_ptr_increment,
743                   const float* filter_ptr, float* acc_buffer_ptr) {
744     float32x4_t filter = vld1q_f32(filter_ptr);
745 
746     // Handle one output pixel at a time.
747     for (int outp = 0; outp < num_output_pixels; outp++) {
748       // Load the inputs
749       float32x4_t input = vld1q_f32(input_ptr);
750       // Load the accumulators from acc_buffer
751       float32x4_t acc = vld1q_f32(acc_buffer_ptr);
752       // Multiply-accumulate
753       acc = vmlaq_f32(acc, input, filter);
754       // Store the accumulators back to acc_buffer
755       vst1q_f32(acc_buffer_ptr, acc);
756       acc_buffer_ptr += 4;
757       input_ptr += input_ptr_increment;
758     }
759   }
760 };
761 #endif
762 
763 // Accumulates the effect of one row of the filter, on a segment of one row
764 // of the output, accessing the corresponding one row of the input.
765 template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
766 void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
767                                 int input_depth, int input_width,
768                                 const float* input_data, int pad_width,
769                                 int depth_multiplier, int filter_width,
770                                 const float* filter_data,
771                                 int out_x_buffer_start, int out_x_buffer_end,
772                                 int output_depth, float* acc_buffer) {
773   ruy::profiler::ScopeLabel label(TFLITE_PRETTY_FUNCTION);
774   // Consistency check parameters. This is important in particular to ensure
775   // that we keep the number of template instantiations minimal, so we don't
776   // increase binary size unnecessarily.
777   static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
778   static_assert(kFixedInputDepth || kAllowStrided, "");
779   TFLITE_DCHECK(stride == 1 || kAllowStrided);
780   if (kFixedInputDepth) {
781     TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
782   }
783   if (kFixedDepthMultiplier) {
784     TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
785   }
786   TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
787   const int input_ptr_increment = stride * input_depth;
788   const float* filter_base_ptr = filter_data;
789   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
790     // For the current (filter_x, filter_y) point in the filter,
791     // compute the boundaries of the corresponding output row segment.
792     int out_x_loop_start_unclamped = 0;
793     int out_x_loop_end_unclamped = 0;
794     if (kAllowStrided) {
795       if (stride == 2) {
796         out_x_loop_start_unclamped =
797             (pad_width - dilation_factor * filter_x + 1) / 2;
798         out_x_loop_end_unclamped =
799             (pad_width + input_width - dilation_factor * filter_x + 1) / 2;
800       } else if (stride == 4) {
801         out_x_loop_start_unclamped =
802             (pad_width - dilation_factor * filter_x + 3) / 4;
803         out_x_loop_end_unclamped =
804             (pad_width + input_width - dilation_factor * filter_x + 3) / 4;
805       } else {
806         out_x_loop_start_unclamped =
807             (pad_width - dilation_factor * filter_x + stride - 1) / stride;
808         out_x_loop_end_unclamped = (pad_width + input_width -
809                                     dilation_factor * filter_x + stride - 1) /
810                                    stride;
811       }
812     } else {
813       out_x_loop_start_unclamped = pad_width - dilation_factor * filter_x;
814       out_x_loop_end_unclamped =
815           pad_width + input_width - dilation_factor * filter_x;
816     }
817     // The kernel will have to iterate on the segment of the
818     // output row that starts at out_x_loop_start and out_x_loop_end.
819     const int out_x_loop_start =
820         std::max(out_x_buffer_start, out_x_loop_start_unclamped);
821     const int out_x_loop_end =
822         std::min(out_x_buffer_end, out_x_loop_end_unclamped);
823 
824     float* acc_buffer_ptr =
825         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
826     const int in_x_origin =
827         (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
828     const float* input_ptr = input_data + in_x_origin * input_depth;
829     const int num_output_pixels = out_x_loop_end - out_x_loop_start;
830     FloatDepthwiseConvKernel<kAllowStrided, kFixedInputDepth,
831                              kFixedDepthMultiplier>::Run(num_output_pixels,
832                                                          input_depth,
833                                                          depth_multiplier,
834                                                          input_ptr,
835                                                          input_ptr_increment,
836                                                          filter_base_ptr,
837                                                          acc_buffer_ptr);
838     filter_base_ptr += output_depth;
839   }
840 }
841 
842 // generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
843 inline void FloatDepthwiseConvAccumRowGeneric(
844     int stride, int dilation_factor, int input_depth, int input_width,
845     const float* input_data, int pad_width, int depth_multiplier,
846     int filter_width, const float* filter_data, int out_x_buffer_start,
847     int out_x_buffer_end, int output_depth, float* acc_buffer) {
848   ruy::profiler::ScopeLabel label("DepthwiseConvAccumRowGeneric (slow)");
849   const float* filter_base_ptr = filter_data;
850   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
851     const int out_x_loop_start = std::max(
852         out_x_buffer_start,
853         (pad_width - dilation_factor * filter_x + stride - 1) / stride);
854     const int out_x_loop_end = std::min(
855         out_x_buffer_end,
856         (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
857             stride);
858 
859     float* acc_buffer_ptr =
860         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
861     const int in_x_origin =
862         (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
863     const float* input_ptr = input_data + in_x_origin * input_depth;
864     const int input_ptr_increment = (stride - 1) * input_depth;
865     for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
866       const float* filter_ptr = filter_base_ptr;
867       for (int ic = 0; ic < input_depth; ++ic) {
868         const float input_val = *input_ptr++;
869         for (int m = 0; m < depth_multiplier; m++) {
870           const float filter_val = *filter_ptr++;
871           *acc_buffer_ptr++ += filter_val * input_val;
872         }
873       }
874       input_ptr += input_ptr_increment;
875     }
876     filter_base_ptr += output_depth;
877   }
878 }
879 
// Initializes the accumulator buffer with bias values.
//
// Writes num_output_pixels consecutive copies of bias_data[0, output_depth)
// into acc_buffer. acc_buffer must hold at least
// num_output_pixels * output_depth floats. A null bias_data is treated as an
// all-zero bias instead of being dereferenced. Non-positive counts write
// nothing.
//
// Uses std::copy_n / std::fill_n from <algorithm>, which this header already
// includes, rather than memcpy, whose header (<cstring>) it does not include.
inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
                                       const float* bias_data,
                                       float* acc_buffer) {
  // TODO(benoitjacob): This might need optimized specializations
  // for small output_depth values, if that ever becomes an important
  // case (like it was for some quantized DepthwiseConv cases).
  for (int i = 0; i < num_output_pixels; i++) {
    float* const pixel_acc = acc_buffer + i * output_depth;
    if (bias_data != nullptr) {
      std::copy_n(bias_data, output_depth, pixel_acc);
    } else {
      std::fill_n(pixel_acc, output_depth, 0.0f);
    }
  }
}
892 
893 // DepthwiseConv can run with multi threads on the dim specified by thread_dim.
894 // Each thread processes output elements on dim, thread_dim, in the range of
895 // [thread_start, thread_end).
896 // For example, assume thread_start = 2, thread_end = 6, and thread_dim = 1, it
897 // means that it will calculate DepthwiseConv for output_data[:, 2:5, :, :].
898 //
899 // The cpu_flags is currently unused. This
900 // parameter is included so that the signature matches that required by a
901 // templated function. Other versions, such as quantized, need this parameter.
902 inline void DepthwiseConvImpl(
903     const DepthwiseParams& params, const RuntimeShape& input_shape,
904     const float* input_data, const RuntimeShape& filter_shape,
905     const float* filter_data, const RuntimeShape& bias_shape,
906     const float* bias_data, const RuntimeShape& output_shape,
907     float* output_data, const CpuFlags& /* cpu_flags */, int thread_start,
908     int thread_end, int thread_dim) {
909   ruy::profiler::ScopeLabel label("DepthwiseConv/float/DepthwiseConvImpl");
910 
911   const int stride_width = params.stride_width;
912   const int stride_height = params.stride_height;
913   const int pad_width = params.padding_values.width;
914   const int pad_height = params.padding_values.height;
915   const int depth_multiplier = params.depth_multiplier;
916   const float output_activation_min = params.float_activation_min;
917   const float output_activation_max = params.float_activation_max;
918   const int dilation_width_factor = params.dilation_width_factor;
919   const int dilation_height_factor = params.dilation_height_factor;
920   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
921   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
922   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
923   TFLITE_DCHECK(thread_dim == 0 || thread_dim == 1);
924 
925   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
926   const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
927   const int input_height = input_shape.Dims(1);
928   const int input_width = input_shape.Dims(2);
929   const int input_depth = input_shape.Dims(3);
930   const int filter_height = filter_shape.Dims(1);
931   const int filter_width = filter_shape.Dims(2);
932   const int output_height = output_shape.Dims(1);
933   const int output_width = output_shape.Dims(2);
934   TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
935   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
936 
937   static const int kAccBufferMaxSize = 4832;
938   float acc_buffer[kAccBufferMaxSize];
939   TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth);
940   const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
941   const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
942   TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
943                    kAccBufferActualSize);
944   TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize);
945   TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
946 
947   // row_accum_func will point to the core accumulation function to be used
948   // for this DepthwiseConv op.
949   using row_accum_func_t = decltype(&FloatDepthwiseConvAccumRowGeneric);
950   row_accum_func_t row_accum_func = nullptr;
951 
952 #define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
953                                         FIXED_DEPTH_MULTIPLIER)           \
954   if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) &&          \
955       (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) &&     \
956       depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                       \
957     row_accum_func =                                                      \
958         FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,      \
959                                    FIXED_DEPTH_MULTIPLIER>;               \
960   }
961 
962 #ifdef USE_NEON
963   // We go over our list of kernels by decreasing order of preference
964   // for the cases where multiple kernels could apply.
965 
966   // Start with the fastest kernels: AllowStrided=false, fixed input depth.
967 
968   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
969   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
970 
971   // Next come the strided kernels: AllowStrided=true, fixed input depth.
972   // They are a bit less efficient, but allow stride!=1.
973 
974   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
975   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
976   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20)
977   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
978   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
979   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 2)
980   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 4)
981   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
982 
983   // Finally, the kernels allowing a variable input depth,
984   // these are the least efficient but most general kernels.
985 
986   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
987   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
988   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 8)
989   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 16)
990 
991 #endif  // USE_NEON
992 
993 #undef TFMINI_USE_DEPTHWISECONV_KERNEL
994 
995   // No matching fast kernel found, use slow fallback.
996   if (!row_accum_func) {
997     row_accum_func = FloatDepthwiseConvAccumRowGeneric;
998   }
999 
1000   const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
1001   const int input_batch_stride = input_height_stride * input_shape.Dims(1);
1002   const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
1003 
1004   // Now that we have determined row_accum_func, we can start work.
1005   int batch_start = 0;
1006   int batch_end = batches;
1007   int row_start = 0;
1008   int row_end = output_height;
1009   int output_ptr_offset = 0;
1010 
1011   switch (thread_dim) {
1012     case 0:
1013       // Multithread along with the batch axis
1014       TFLITE_DCHECK_GE(thread_start, 0);
1015       TFLITE_DCHECK_LE(thread_end, batches);
1016       batch_start = thread_start;
1017       batch_end = thread_end;
1018       output_ptr_offset = batch_start * FlatSizeSkipDim(output_shape, 0);
1019       break;
1020     case 1:
1021       // Multithread along with the row axis
1022       TFLITE_DCHECK_GE(thread_start, 0);
1023       TFLITE_DCHECK_LE(thread_end, output_height);
1024       row_start = thread_start;
1025       row_end = thread_end;
1026       output_ptr_offset = row_start * output_width * output_depth;
1027       break;
1028   }
1029 
1030   float* output_ptr = output_data + output_ptr_offset;
1031   int batch_step =
1032       (output_height + row_start - row_end) * output_width * output_depth;
1033 
1034   for (int b = batch_start; b < batch_end; ++b) {
1035     for (int out_y = row_start; out_y < row_end; ++out_y) {
1036       const int in_y_origin = (out_y * stride_height) - pad_height;
1037       const int filter_y_start =
1038           std::max(0, (-in_y_origin + dilation_height_factor - 1) /
1039                           dilation_height_factor);
1040       const int filter_y_end =
1041           std::min(filter_height,
1042                    (input_height - in_y_origin + dilation_height_factor - 1) /
1043                        dilation_height_factor);
1044       for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
1045            out_x_buffer_start += kOutputPixelsInAccBuffer) {
1046         const int out_x_buffer_end = std::min(
1047             output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
1048         // We call a 'pixel' a group of activation that share all but the
1049         // 'depth'/'channel' coordinate. num_output_pixels is the number of
1050         // output pixels that we will accumulate in this loop iteration.
1051         const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
1052         // Initialize our local accumulator with the bias values, so we don't
1053         // have to add them later.
1054         DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
1055                                    acc_buffer);
1056         // Accumulation loop. Most of the time should be spent in here.
1057         for (int filter_y = filter_y_start; filter_y < filter_y_end;
1058              ++filter_y) {
1059           const int in_y = in_y_origin + dilation_height_factor * filter_y;
1060           row_accum_func(
1061               stride_width, dilation_width_factor, input_depth, input_width,
1062               input_data + in_y * input_height_stride + b * input_batch_stride,
1063               pad_width, depth_multiplier, filter_width,
1064               filter_data + filter_y * filter_height_stride, out_x_buffer_start,
1065               out_x_buffer_end, output_depth, acc_buffer);
1066         }
1067         // Finished accumulating. Now store to destination.
1068         const int num_output_values = output_depth * num_output_pixels;
1069         int i = 0;
1070 // TODO(benoitjacob) optimized code goes here
1071 #ifdef USE_NEON
1072         // Handle 16 values at a time
1073         for (; i <= num_output_values - 16; i += 16) {
1074           float32x4_t acc[4];
1075           for (int k = 0; k < 4; k++) {
1076             acc[k] = vld1q_f32(acc_buffer + i + 4 * k);
1077           }
1078           for (int k = 0; k < 4; k++) {
1079             acc[k] = vmaxq_f32(
1080                 vdupq_n_f32(output_activation_min),
1081                 vminq_f32(vdupq_n_f32(output_activation_max), acc[k]));
1082           }
1083           for (int k = 0; k < 4; k++) {
1084             vst1q_f32(output_ptr + 4 * k, acc[k]);
1085           }
1086           output_ptr += 16;
1087         }
1088         // Handle 4 values at a time
1089         for (; i <= num_output_values - 4; i += 4) {
1090           float32x4_t acc = vld1q_f32(acc_buffer + i);
1091 
1092           acc = vmaxq_f32(vdupq_n_f32(output_activation_min),
1093                           vminq_f32(vdupq_n_f32(output_activation_max), acc));
1094 
1095           vst1q_f32(output_ptr, acc);
1096           output_ptr += 4;
1097         }
1098 #endif
1099         // Handle leftover values, one by one. This is very slow.
1100         for (; i < num_output_values; i++) {
1101           float acc = acc_buffer[i];
1102           acc = std::max(output_activation_min,
1103                          std::min(output_activation_max, acc));
1104 
1105           *output_ptr++ = acc;
1106         }
1107       }
1108     }
1109     output_ptr += batch_step;
1110   }
1111 }
1112 
1113 
1114 }  // namespace optimized_ops
1115 }  // namespace tflite
1116 
1117 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
1118