// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at::functorch {

// convolution_batch_rule translated from jax with modifications:
// https://github.com/google/jax/blob/master/jax/_src/lax/lax.py#L3143

// PyTorch's convolution is different from JAX's conv_general_dilated:
// we do not support batch_group_count (which is needed for convolution backwards).
// Instead, there's a convolution_backward op that needs a batching rule.
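// Rough shape sketch of the cases handled below (illustrative layouts only;
// the batch dim B may sit at any position given by lhs_bdim / rhs_bdim):
// - lhs batched only:  fold B into N, one conv call, split B back out of N.
// - rhs batched only:  fold B into the output-channel dim (with extra care
//                      when groups > 1 or the conv is transposed).
// - both batched:      fold B into the channel/group dims and convolve with
//                      groups * B groups.
// - neither batched:   fall through to a plain at::convolution call.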
static std::tuple<Tensor, std::optional<int64_t>>
convolution_batch_rule(const Tensor& lhs, std::optional<int64_t> lhs_bdim, const Tensor& rhs, std::optional<int64_t> rhs_bdim, const std::optional<Tensor>& bias, std::optional<int64_t> bias_bdim, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
  DimVector lhs_spec(stride.size() + 2);
  std::iota(lhs_spec.begin(), lhs_spec.end(), 0);
  DimVector rhs_spec = lhs_spec;
  DimVector out_spec = lhs_spec;
  if (transposed) {
    rhs_spec[0] = 1;
    rhs_spec[1] = 0;
  }

  // If the weight is batched or the bias itself is batched, the bias cannot be
  // passed to at::convolution (it expects an unbatched, per-output-channel
  // bias), so we run the convolution without a bias and add it separately below.
  std::optional<Tensor> unbatched_bias;
  bool separate_bias = false;
  if ((rhs_bdim && bias && bias->defined()) || bias_bdim) {
    TORCH_INTERNAL_ASSERT(bias.has_value());
    TORCH_INTERNAL_ASSERT(bias->defined());
    unbatched_bias = std::nullopt;
    separate_bias = true;
  } else {
    unbatched_bias = bias;
    separate_bias = false;
  }
  std::tuple<Tensor, std::optional<int64_t>> result;
  if (lhs_bdim && !rhs_bdim) {
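    // Only the input is batched: fold the batch dim into N, do a single
    // convolution, then split the batch dim back out of N.
    // BNIHW, OIHW -> (BN)IHW, OIHW -> (BN)OHW -> BNOHW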
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs);
    auto out = at::convolution_symint(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof_symint(out_spec[0], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[0]);
  } else if (!lhs_bdim && rhs_bdim) {
    if (groups == 1) {
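      // Only the weight is batched and there are no groups: fold the batch dim
      // into the weight's output-channel dim, convolve, then split the batch
      // dim back out of the output's channel dim.
      // NIHW, BOIHW -> NIHW, (BO)IHW -> N(BO)HW -> NBOHW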
      auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs);
      auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
      out = reshape_dim_outof_symint(out_spec[1], rhs.size(*rhs_bdim), out);
      result = std::make_tuple(out, out_spec[1]);
    } else {
      if (transposed) {
        // conv_transpose with groups is normally NIHW, IOHW -> N(GO)HW
        // With RHS batched, we do the following:
        // NIHW, BIOHW -> NIHW, I(BO)HW -> N(GBO)HW -> NB(GO)HW
        // NB: the following isn't written using rhs_spec
        // (PyTorch convs have a fixed dimension order)

        // BIOHW -> I(BO)HW
        auto new_w = reshape_dim_into(*rhs_bdim, 1, rhs);
        // NIHW, I(BO)HW -> N(GBO)HW
        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
        // N(GBO)HW -> NG(BO)HW
        out = reshape_dim_outof_symint(1, groups, out);
        // NG(BO)HW -> NGBOHW
        out = reshape_dim_outof_symint(2, rhs.size(*rhs_bdim), out);
        // NGBOHW -> NB(GO)HW
        out = reshape_dim_into(1, 2, out);
        result = std::make_tuple(out, 1);
      } else {
        // conv with groups is normally N(GI)HW, (GO)IHW -> N(GO)HW
        // With RHS batched, we do the following:
        // N(GI)HW, B(GO)IHW -> N(GI)HW, (GBO)IHW -> N(GBO)HW -> NB(GO)HW
        // NB: the following isn't written using rhs_spec
        // (PyTorch convs have a fixed dimension order)

        // B(GO)IHW -> BGOIHW
        auto new_w = reshape_dim_outof_symint(0 + (*rhs_bdim == 0), groups, rhs);
        // BGOIHW -> G(BO)IHW
        new_w = reshape_dim_into(*rhs_bdim + (*rhs_bdim > 0), 1, new_w);
        // G(BO)IHW -> (GBO)IHW
        new_w = reshape_dim_into(0, 0, new_w);
        // N(GI)HW, (GBO)IHW -> N(GBO)HW
        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
        // N(GBO)HW -> NG(BO)HW
        out = reshape_dim_outof_symint(1, groups, out);
        // NG(BO)HW -> NGBOHW
        out = reshape_dim_outof_symint(2, rhs.size(*rhs_bdim), out);
        // NGBOHW -> NB(GO)HW
        out = reshape_dim_into(1, 2, out);
        result = std::make_tuple(out, 1);
      }
    }
  } else if (lhs_bdim && rhs_bdim) {
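    // Both the input and the weight are batched: fold the batch dim into the
    // channel dims (dim 1 of the input, the group-carrying dim of the weight)
    // and run a single convolution with groups * B groups, so each batch
    // element is convolved with its own weight slice. The batch dim is then
    // split back out of the output's channel dim.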
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[1], lhs);
    groups *= lhs.sizes()[*lhs_bdim];
    auto dim_with_groups = transposed ? 1 : 0;
    auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs);
    auto out = at::convolution_symint(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof_symint(out_spec[1], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[1]);
  } else {
    result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), std::nullopt);
  }
  if (separate_bias) {
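    // Add the (possibly batched) bias by hand: move any batch dims to the
    // front, unsqueeze the bias over the spatial dims, pad it to the output's
    // logical rank, and broadcast-add it to the convolution result.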
    auto A = std::get<0>(result);
    auto A_batch_dim = std::get<1>(result);
    auto B = *bias;
    auto B_batch_dim = bias_bdim;
    A = moveBatchDimToFront(A, A_batch_dim);
    B = moveBatchDimToFront(B, B_batch_dim);
    for (size_t i = 0; i < out_spec.size() - 2; i++) {
      B = B.unsqueeze(-1);
    }
    B = maybePadToLogicalRank(B, B_batch_dim, rankWithoutBatchDim(A, A_batch_dim));

    return std::make_tuple(at::add(A, B), 0);
  } else {
    return result;
  }
}

static Tensor _convolution_decomp(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
  // Ignore the backend-selection flags (benchmark, deterministic, cudnn_enabled,
  // allow_tf32) and decompose into at::convolution. If the user called this in
  // the normal way, then they should be fine.
  (void) benchmark;
  (void) deterministic;
  (void) cudnn_enabled;
  (void) allow_tf32;
  return at::convolution(
      input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_);
}

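// Sums grad_output over every dim except the channel dim (dim 1) to produce
// the bias gradient. Returns an undefined Tensor if the bias gradient was not
// requested in output_mask.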
static Tensor compute_grad_bias(
    const Tensor& grad_output_, std::array<bool, 3> output_mask) {
  if (!output_mask[2]) {
    return Tensor();
  }
  DimVector reduce_dims;
  reduce_dims.resize(grad_output_.dim() - 1);
  reduce_dims[0] = 0;
  std::iota(reduce_dims.begin() + 1, reduce_dims.end(), 2);
  return grad_output_.sum(reduce_dims);
}

// Returns a shape-only "dummy" tensor: `tensor` with batch_size reshaped into
// dim (so dim's size becomes batch_size * orig_size), built by expanding an
// empty scalar so no real data is materialized.
static Tensor make_dummy(
    const Tensor& tensor, std::optional<int64_t> tensor_bdim,
    int64_t dim, int64_t batch_size) {
  auto tensor_ = tensor_bdim ? tensor.select(*tensor_bdim, 0) : tensor;
  auto orig_size = tensor_.size(dim);
  tensor_ = tensor_.slice(dim, 0, 1);

  DimVector expand_shape(tensor_.sizes().begin(), tensor_.sizes().end());
  expand_shape[dim] = batch_size * orig_size;

  return tensor_.new_empty({}).expand(expand_shape);
}

static std::tuple<Tensor, std::optional<int64_t>>
convolution_backward_input_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, const c10::SymInt& groups) {
  const std::array<bool, 3> mask = {true, false, false};
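  // Only grad_input is requested (mask = {true, false, false}), so `input` is
  // only needed for its shape and is passed to convolution_backward as a
  // make_dummy tensor.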
  if (grad_output_bdim && weight_bdim) {
    // regular: BNO, BOI -> N(BO), (BO)I -> N(BI)
    // transposed: BNO, BIO -> N(BO), (BI)O -> N(BI)
    const auto batch_size = weight.size(*weight_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
    auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, dummy_input, weight_, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
    const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
    return std::make_tuple(grad_input, 1);
  } else if (grad_output_bdim && !weight_bdim) {
    // BNO, OI -> (BN)O, OI -> (BN)I
    // transposed is the same.
    const auto batch_size = grad_output.size(*grad_output_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
    auto dummy_input = make_dummy(input, input_bdim, 0, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, dummy_input, weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
    return std::make_tuple(grad_input, 0);
  } else if (!grad_output_bdim && weight_bdim) {
    const auto batch_size = weight.size(*weight_bdim);
    if (groups == 1) {
      // regular: NO, BOI -> NO, O(BI) -> N(BI)
      // transposed: NO, BIO -> NO, (BI)O -> N(BI)
      const auto in_ch_dim = transposed ? 0 : 1;
      const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight);
      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
      return std::make_tuple(grad_input, 1);
    }
    Tensor grad_input;
    if (!transposed) {
      // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
      const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      grad_input = std::get<0>(result); // N(GBI)
    } else {
      // N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI)
      auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O
      weight_ = reshape_dim_outof_symint(1, groups, weight_);  // BGIO
      weight_ = weight_.transpose(0, 1);                       // GBIO
      weight_ = weight_.flatten(0, 2);                         // (GBI)O
      const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      grad_input = std::get<0>(result); // N(GBI)
    }
    // N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI)
    grad_input = reshape_dim_outof_symint(1, groups, grad_input);
    grad_input = reshape_dim_outof_symint(2, batch_size, grad_input);
    grad_input = grad_input.transpose(1, 2);
    grad_input = reshape_dim_into(2, 2, grad_input);
    return std::make_tuple(grad_input, 1);
  } else {
    TORCH_INTERNAL_ASSERT(input_bdim);
    const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
    const auto result = at::convolution_backward_symint(
        grad_output, dummy_input, weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    return std::make_tuple(std::get<0>(result), std::nullopt);
  }
}

static std::tuple<Tensor, std::optional<int64_t>>
convolution_backward_weight_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, const c10::SymInt& groups) {
  const std::array<bool, 3> mask = {false, true, false};
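  // Only grad_weight is requested (mask = {false, true, false}), so `weight`
  // is only needed for its shape and is passed to convolution_backward as a
  // make_dummy tensor.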
  if (grad_output_bdim && input_bdim) {
    // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed)
    const auto batch_size = input.size(*input_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto input_ = reshape_dim_into(*input_bdim, 1, input);
    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, input_, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
    auto grad_weight = std::get<1>(result);
    grad_weight = reshape_dim_outof_symint(0, batch_size, grad_weight);
    return std::make_tuple(grad_weight, 0);
  } else if (grad_output_bdim && !input_bdim) {
    const auto batch_size = grad_output.size(*grad_output_bdim);
    if (groups == 1) {
      // regular: BNO, NI -> N(BO), NI -> (BO)I
      // transposed: BNO, NI -> N(BO), NI -> I(BO)
      const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
      const auto out_ch_dim = transposed ? 1 : 0;
      const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output_, input, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      auto grad_weight = std::get<1>(result);
      grad_weight = reshape_dim_outof_symint(out_ch_dim, batch_size, grad_weight);
      return std::make_tuple(grad_weight, out_ch_dim);
    } else {
      auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO)
      grad_output_ = reshape_dim_outof_symint(2, groups, grad_output_);       // BNGO
      grad_output_ = grad_output_.movedim(0, 2);                              // NGBO
      grad_output_ = grad_output_.flatten(1, 3);                              // N(GBO)
      if (!transposed) {
        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBOI
        grad_weight = grad_weight.transpose(0, 1);                                 // BGOI
        grad_weight = grad_weight.flatten(1, 2);                                   // B(GO)I
        return std::make_tuple(grad_weight, 0);
      } else {
        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
        return std::make_tuple(grad_weight, 1);
      }
    }
  } else if (!grad_output_bdim && input_bdim) {
    const auto batch_size = input.size(*input_bdim);
    if (groups == 1) {
      // regular: NO, BNI -> NO, N(BI) -> O(BI)
      // transposed: NO, BNI -> NO, N(BI) -> (BI)O
      const auto input_ = reshape_dim_into(*input_bdim, 1, input);
      const auto in_ch_dim = transposed ? 0 : 1;
      const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, input_, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      auto grad_weight = std::get<1>(result);
      grad_weight = reshape_dim_outof_symint(in_ch_dim, batch_size, grad_weight);
      return std::make_tuple(grad_weight, in_ch_dim);
    } else {
      auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI)
      input_ = reshape_dim_outof_symint(2, groups, input_); // BNGI
      input_ = input_.movedim(0, 2);                        // NGBI
      input_ = input_.flatten(1, 3);                        // N(GBI)
      if (!transposed) {
        // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
        return std::make_tuple(grad_weight, 1);
      } else {
        // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBIO
        grad_weight = grad_weight.transpose(0, 1);                                 // BGIO
        grad_weight = grad_weight.flatten(1, 2);                                   // B(GI)O
        return std::make_tuple(grad_weight, 0);
      }
    }
  } else {
    TORCH_INTERNAL_ASSERT(weight_bdim);
    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
    const auto result = at::convolution_backward_symint(
        grad_output, input, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    return std::make_tuple(std::get<1>(result), std::nullopt);
  }
}

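// Plumbing for convolution_backward: unwraps the BatchedTensors at the current
// vmap level, computes grad_bias eagerly, then dispatches grad_input and
// grad_weight to the batch rules above and re-wraps the results.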
static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const c10::OptionalArrayRef<SymInt> bias_sizes_opt,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, c10::SymInt groups, std::array<bool, 3> output_mask) {
  const auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "convolution_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({grad_output_, input_, weight_}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::convolution_backward_symint(
        grad_output_, input_, weight_, bias_sizes_opt, stride, padding,
        dilation, transposed, output_padding, groups, output_mask);
  }

  auto [grad_output, grad_output_bdim] = unwrapTensorAtLevel(grad_output_, cur_level);
  auto [input, input_bdim] = unwrapTensorAtLevel(input_, cur_level);
  auto [weight, weight_bdim] = unwrapTensorAtLevel(weight_, cur_level);

  const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
  output_mask[2] = false;

  // TODO: A little bird says that unfold + matmul is actually faster than
  // group convolution in many cases. We should benchmark some of
  // the common cases and replace things with unfold + matmul as necessary.

  // Notation:
  // B - a batch dimension
  // G - groups (sometimes omitted because it doesn't matter)
  // NO - grad_output
  // NI - input
  // OI - weight
  // "(BO)I" - we don't actually care about the values of this Tensor,
  //           we just need to create a tensor on the same device with the
  //           correct shape and pray that the implementation is smart enough
  //           to not do anything with it.

  // BNO, BNI, BOI
  // a.k.a. one of the model-ensembling cases
  if (grad_output_bdim && input_bdim && weight_bdim) {
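    // All three tensors are batched: fold the batch dim into the channel dims
    // and call convolution_backward once with batch_size * groups groups, the
    // same trick used in convolution_batch_rule.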
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output);

    // BNO, BNI, BOI -> N(BO), N(BI), (BO)I
    const auto batch_size = weight.size(*weight_bdim);
    input = reshape_dim_into(*input_bdim, 1, input);
    weight = reshape_dim_into(*weight_bdim, 0, weight);
    const auto result = at::convolution_backward_symint(
        grad_output, input, weight, std::nullopt, stride, padding, dilation,
        transposed, output_padding, batch_size * groups, output_mask);
    // N(BI), (BO)I -> NBI, BOI
    const auto grad_input = output_mask[0] ?
      reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
    const auto grad_weight = output_mask[1] ?
      reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
    return std::make_tuple(
        output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
        output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
        grad_bias);
  }

  Tensor grad_input;
  if (output_mask[0]) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = convolution_backward_input_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }

  Tensor grad_weight;
  if (output_mask[1]) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = convolution_backward_weight_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);

  // Someone's definitely going to find a problem with this batching rule, so
  // I'm leaving the following fallback here in case we need it back.
  // static auto op = c10::Dispatcher::singleton()
  //   .findSchemaOrThrow("aten::convolution_backward", "");
  // auto result = slow_fallback<Tensor,Tensor,Tensor>(op, {
  //   grad_output_, input_, weight_, bias_sizes_opt,
  //   stride, padding, dilation, transposed, output_padding, groups, output_mask
  // });
  // return std::make_tuple(grad_input, std::get<1>(result), grad_bias);
}


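// Registrations: convolution gets the batch rule above, _convolution is
// decomposed into convolution, and convolution_backward goes through the
// manual plumbing above.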
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(convolution, convolution_batch_rule);
  m.impl("_convolution", _convolution_decomp);
  m.impl("convolution_backward", convolution_backward_plumbing);
}

} // namespace at::functorch