// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at::functorch {

// convolution_batch_rule translated from jax with modifications:
// https://github.com/google/jax/blob/master/jax/_src/lax/lax.py#L3143

// PyTorch's convolution is different from JAX's conv_general_dilated:
// we do not support batch_group_count (which is needed for convolution backwards).
// Instead, there's a convolution_backward op that needs a batching rule.
static std::tuple<Tensor, std::optional<int64_t>>
convolution_batch_rule(const Tensor& lhs, std::optional<int64_t> lhs_bdim, const Tensor& rhs, std::optional<int64_t> rhs_bdim, const std::optional<Tensor>& bias, std::optional<int64_t> bias_bdim, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
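  // lhs_spec/rhs_spec/out_spec play the role of JAX's dimension_numbers. Each is
  // just [0, 1, ..., rank-1]; spec[0] is the dim holding the batch (for lhs/out)
  // or the output channels (for rhs), and spec[1] the dim holding the (input)
  // channels. Transposed convolution weights are laid out as (in, out, *spatial),
  // so the first two entries of rhs_spec are swapped below.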
  DimVector lhs_spec(stride.size() + 2);
  std::iota(lhs_spec.begin(), lhs_spec.end(), 0);
  DimVector rhs_spec = lhs_spec;
  DimVector out_spec = lhs_spec;
  if (transposed) {
    rhs_spec[0] = 1;
    rhs_spec[1] = 0;
  }

  // If we have a batched bias or weight, we need to perform the computation separately.
  std::optional<Tensor> unbatched_bias;
  bool separate_bias = false;
  if ((rhs_bdim && bias && bias->defined()) || bias_bdim) {
    TORCH_INTERNAL_ASSERT(bias.has_value());
    TORCH_INTERNAL_ASSERT(bias->defined());
    unbatched_bias = std::nullopt;
    separate_bias = true;
  } else {
    unbatched_bias = bias;
    separate_bias = false;
  }
  std::tuple<Tensor, std::optional<int64_t>> result;
  if (lhs_bdim && !rhs_bdim) {
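    // Only the input is batched: fold the vmap dim B into the input's batch dim
    // (e.g. (B, N, C, H, W) -> (B*N, C, H, W)), run the unbatched convolution,
    // then split B back out of the output's batch dim.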
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs);
    auto out = at::convolution_symint(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof_symint(out_spec[0], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[0]);
  } else if (!lhs_bdim && rhs_bdim) {
    if (groups == 1) {
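      // Only the weight is batched (and groups == 1): fold the vmap dim B into the
      // weight's output-channel dim, so the conv produces B*O output channels,
      // then split B back out of the output's channel dim (out_spec[1]).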
      auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs);
      auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
      out = reshape_dim_outof_symint(out_spec[1], rhs.size(*rhs_bdim), out);
      result = std::make_tuple(out, out_spec[1]);
    } else {
      if (transposed) {
        // conv_transpose with groups is normally NIHW, IOHW -> N(GO)HW
        // With RHS batched, we do the following:
        // NIHW, BIOHW -> NIHW, I(BO)HW -> N(GBO)HW -> BN(GO)HW
        // NB: the following isn't written using rhs_spec
        // (PyTorch convs have a fixed dimension order)

        // BIOHW -> I(BO)HW
        auto new_w = reshape_dim_into(*rhs_bdim, 1, rhs);
        // NIHW, I(BO)HW -> N(GBO)HW
        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
        // N(GBO)HW -> NG(BO)HW
        out = reshape_dim_outof_symint(1, groups, out);
        // NG(BO)HW -> NGBOHW
        out = reshape_dim_outof_symint(2, rhs.size(*rhs_bdim), out);
        // NGBOHW -> NB(GO)HW
        out = reshape_dim_into(1, 2, out);
        result = std::make_tuple(out, 1);
      } else {
        // conv with groups is normally N(GI)HW, (GO)IHW -> N(GO)HW
        // With RHS batched, we do the following:
        // N(GI)HW, B(GO)IHW -> N(GI)HW, (GBO)IHW -> N(GBO)HW -> BN(GO)HW
        // NB: the following isn't written using rhs_spec
        // (PyTorch convs have a fixed dimension order)

        // B(GO)IHW -> BGOIHW
        auto new_w = reshape_dim_outof_symint(0 + (*rhs_bdim == 0), groups, rhs);
        // BGOIHW -> G(BO)IHW
        new_w = reshape_dim_into(*rhs_bdim + (*rhs_bdim > 0), 1, new_w);
        // G(BO)IHW -> (GBO)IHW
        new_w = reshape_dim_into(0, 0, new_w);
        // N(GI)HW, (GBO)IHW -> N(GBO)HW
        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
        // N(GBO)HW -> NG(BO)HW
        out = reshape_dim_outof_symint(1, groups, out);
        // NG(BO)HW -> NGBOHW
        out = reshape_dim_outof_symint(2, rhs.size(*rhs_bdim), out);
        // NGBOHW -> NB(GO)HW
        out = reshape_dim_into(1, 2, out);
        result = std::make_tuple(out, 1);
      }
    }
  } else if (lhs_bdim && rhs_bdim) {
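    // Both input and weight are batched: treat the vmap dim as extra groups.
    // Fold B into the input's channel dim and into the group-carrying dim of the
    // weight, multiply groups by B so each batch element convolves with its own
    // weight slice, then split B back out of the output's channel dim.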
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[1], lhs);
    groups *= lhs.sizes()[*lhs_bdim];
    auto dim_with_groups = transposed ? 1 : 0;
    auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs);
    auto out = at::convolution_symint(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof_symint(out_spec[1], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[1]);
  } else {
    result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), std::nullopt);
  }
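  // If the bias could not be folded into the convolution above (because the bias
  // or the weight is batched), add it by hand: move batch dims to the front,
  // unsqueeze the bias once per spatial dim so it lines up with the output's
  // channel dim, and broadcast-add.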
  if (separate_bias) {
    auto A = std::get<0>(result);
    auto A_batch_dim = std::get<1>(result);
    auto B = *bias;
    auto B_batch_dim = bias_bdim;
    A = moveBatchDimToFront(A, A_batch_dim);
    B = moveBatchDimToFront(B, B_batch_dim);
    for (size_t i = 0; i < out_spec.size() - 2; i++) {
      B = B.unsqueeze(-1);
    }
    B = maybePadToLogicalRank(B, B_batch_dim, rankWithoutBatchDim(A, A_batch_dim));

    return std::make_tuple(at::add(A, B), 0);
  } else {
    return result;
  }
}

static Tensor _convolution_decomp(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
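  // aten::_convolution is aten::convolution plus a handful of backend-selection
  // flags, so under vmap we simply decompose it into aten::convolution; that way
  // only aten::convolution needs a real batching rule (see the registrations at
  // the bottom of this file).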
  // Ignore the extra flags (benchmark, deterministic, cudnn_enabled, allow_tf32).
  // If the user called this in the normal way, at::convolution will pick suitable
  // values for them again, so they should be fine.
  (void) benchmark;
  (void) deterministic;
  (void) cudnn_enabled;
  (void) allow_tf32;
  return at::convolution(
      input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_);
}

static Tensor compute_grad_bias(
    const Tensor& grad_output_, std::array<bool, 3> output_mask) {
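  // The bias gradient is just grad_output summed over every dim except the channel
  // dim (dim 1). The plumbing below calls this eagerly on the still-batched
  // grad_output (the sum itself vmaps fine), so the bias gradient never has to go
  // through the convolution batch rules.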
  if (!output_mask[2]) {
    return Tensor();
  }
  DimVector reduce_dims;
  reduce_dims.resize(grad_output_.dim() - 1);
  reduce_dims[0] = 0;
  std::iota(reduce_dims.begin() + 1, reduce_dims.end(), 2);
  return grad_output_.sum(reduce_dims);
}

// Returns a "dummy" tensor with the device/dtype of `tensor` and the shape `tensor`
// would have after folding `batch_size` into `dim`. Its values are never meaningful:
// convolution_backward is only expected to look at the shapes of arguments whose
// gradients are not requested (see the Notation comment in
// convolution_backward_plumbing below).
static Tensor make_dummy(
    const Tensor& tensor, std::optional<int64_t> tensor_bdim,
    int64_t dim, int64_t batch_size) {
  auto tensor_ = tensor_bdim ? tensor.select(*tensor_bdim, 0) : tensor;
  auto orig_size = tensor_.size(dim);
  tensor_ = tensor_.slice(dim, 0, 1);

  DimVector expand_shape(tensor_.sizes().begin(), tensor_.sizes().end());
  expand_shape[dim] = batch_size * orig_size;

  return tensor_.new_empty({}).expand(expand_shape);
}

static std::tuple<Tensor, std::optional<int64_t>>
convolution_backward_input_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, const c10::SymInt& groups) {
  const std::array<bool, 3> mask = {true, false, false};
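  // Only grad_input is requested here (mask = {true, false, false}); `input` is only
  // consulted for its shape, so we pass a cheap dummy with the right shape instead of
  // materializing a batched input (see the Notation comment in
  // convolution_backward_plumbing below).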
  if (grad_output_bdim && weight_bdim) {
    // regular: BNO, BOI -> N(BO), (BO)I -> N(BI)
    // transposed: BNO, BIO -> N(BO), (BI)O -> N(BI)
    const auto batch_size = weight.size(*weight_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
    auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, dummy_input, weight_, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
    const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
    return std::make_tuple(grad_input, 1);
  } else if (grad_output_bdim && !weight_bdim) {
    // BNO, OI -> (BN)O, OI -> (BN)I
    // transposed is the same.
    const auto batch_size = grad_output.size(*grad_output_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
    auto dummy_input = make_dummy(input, input_bdim, 0, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, dummy_input, weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
    return std::make_tuple(grad_input, 0);
  } else if (!grad_output_bdim && weight_bdim) {
    const auto batch_size = weight.size(*weight_bdim);
    if (groups == 1) {
      // regular: NO, BOI -> NO, O(BI) -> N(BI)
      // transposed: NO, BIO -> NO, (BI)O -> N(BI)
      const auto in_ch_dim = transposed ? 0 : 1;
      const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight);
      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
      return std::make_tuple(grad_input, 1);
    }
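    // groups > 1: with the weight batched, the convolution hands back grad_input
    // with channels laid out as (G, B, I). The reshapes after the if/else below
    // reorder that to (B, G, I) so the vmap dim can be returned as dim 1.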
    Tensor grad_input;
    if (!transposed) {
      // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
      const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      grad_input = std::get<0>(result); // N(GBI)
    } else {
      // N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI)
      auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O
      weight_ = reshape_dim_outof_symint(1, groups, weight_); // BGIO
      weight_ = weight_.transpose(0, 1); // GBIO
      weight_ = weight_.flatten(0, 2); // (GBI)O
      const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      grad_input = std::get<0>(result); // N(GBI)
    }
    // N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI)
    grad_input = reshape_dim_outof_symint(1, groups, grad_input);
    grad_input = reshape_dim_outof_symint(2, batch_size, grad_input);
    grad_input = grad_input.transpose(1, 2);
    grad_input = reshape_dim_into(2, 2, grad_input);
    return std::make_tuple(grad_input, 1);
  } else {
    TORCH_INTERNAL_ASSERT(input_bdim);
    const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
    const auto result = at::convolution_backward_symint(
        grad_output, dummy_input, weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    return std::make_tuple(std::get<0>(result), std::nullopt);
  }
}
static std::tuple<Tensor, std::optional<int64_t>>
convolution_backward_weight_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, const c10::SymInt& groups) {
  const std::array<bool, 3> mask = {false, true, false};
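  // Only grad_weight is requested here (mask = {false, true, false}); `weight` is
  // passed as a shape-only dummy for the same reason as in the input rule above.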
  if (grad_output_bdim && input_bdim) {
    // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed)
    const auto batch_size = input.size(*input_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto input_ = reshape_dim_into(*input_bdim, 1, input);
    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, input_, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
    auto grad_weight = std::get<1>(result);
    grad_weight = reshape_dim_outof_symint(0, batch_size, grad_weight);
    return std::make_tuple(grad_weight, 0);
  } else if (grad_output_bdim && !input_bdim) {
    const auto batch_size = grad_output.size(*grad_output_bdim);
    if (groups == 1) {
      // regular: BNO, NI -> N(BO), NI -> (BO)I
      // transposed: BNO, NI -> N(BO), NI -> I(BO)
      const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
      const auto out_ch_dim = transposed ? 1 : 0;
      const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output_, input, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      auto grad_weight = std::get<1>(result);
      grad_weight = reshape_dim_outof_symint(out_ch_dim, batch_size, grad_weight);
      return std::make_tuple(grad_weight, out_ch_dim);
    } else {
      auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO)
      grad_output_ = reshape_dim_outof_symint(2, groups, grad_output_); // BNGO
      grad_output_ = grad_output_.movedim(0, 2); // NGBO
      grad_output_ = grad_output_.flatten(1, 3); // N(GBO)
      if (!transposed) {
        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBOI
        grad_weight = grad_weight.transpose(0, 1); // BGOI
        grad_weight = grad_weight.flatten(1, 2); // B(GO)I
        return std::make_tuple(grad_weight, 0);
      } else {
        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
        return std::make_tuple(grad_weight, 1);
      }
    }
  } else if (!grad_output_bdim && input_bdim) {
    const auto batch_size = input.size(*input_bdim);
    if (groups == 1) {
      // regular: NO, BNI -> NO, N(BI) -> O(BI)
      // transposed: NO, BNI -> NO, N(BI) -> (BI)O
      const auto input_ = reshape_dim_into(*input_bdim, 1, input);
      const auto in_ch_dim = transposed ? 0 : 1;
      const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, input_, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      auto grad_weight = std::get<1>(result);
      grad_weight = reshape_dim_outof_symint(in_ch_dim, batch_size, grad_weight);
      return std::make_tuple(grad_weight, in_ch_dim);
    } else {
      auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI)
      input_ = reshape_dim_outof_symint(2, groups, input_); // BNGI
      input_ = input_.movedim(0, 2); // NGBI
      input_ = input_.flatten(1, 3); // N(GBI)
      if (!transposed) {
        // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
        return std::make_tuple(grad_weight, 1);
      } else {
        // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBIO
        grad_weight = grad_weight.transpose(0, 1); // BGIO
        grad_weight = grad_weight.flatten(1, 2); // B(GI)O
        return std::make_tuple(grad_weight, 0);
      }
    }
  } else {
    TORCH_INTERNAL_ASSERT(weight_bdim);
    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
    const auto result = at::convolution_backward_symint(
        grad_output, input, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    return std::make_tuple(std::get<1>(result), std::nullopt);
  }
}

static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const c10::OptionalArrayRef<SymInt> bias_sizes_opt,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, c10::SymInt groups, std::array<bool, 3> output_mask) {
  const auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "convolution_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({grad_output_, input_, weight_}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::convolution_backward_symint(
        grad_output_, input_, weight_, bias_sizes_opt, stride, padding,
        dilation, transposed, output_padding, groups, output_mask);
  }

  auto [grad_output, grad_output_bdim] = unwrapTensorAtLevel(grad_output_, cur_level);
  auto [input, input_bdim] = unwrapTensorAtLevel(input_, cur_level);
  auto [weight, weight_bdim] = unwrapTensorAtLevel(weight_, cur_level);

  const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
  output_mask[2] = false;
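  // The bias gradient was computed above directly from the (possibly batched)
  // grad_output, so clear its slot in the mask before invoking the per-gradient
  // batch rules below.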

  // TODO: A little bird says that unfold + matmul is actually faster than
  // group convolution in many cases. We should benchmark some of
  // the common cases and replace things with unfold + matmul as necessary.

  // Notation:
  // B - a batch dimension
  // G - groups (sometimes omitted because it doesn't matter)
  // NO - grad_output
  // NI - input
  // OI - weight
  // "(BO)I" - we don't actually care about the values of this Tensor,
  //           we just need to create a tensor on the same device with the
  //           correct shape and pray that the implementation is smart enough
  //           to not do anything with it.

  // BNO, BNI, BOI
  // AKA the model-ensembling case (all three tensors are batched)
  if (grad_output_bdim && input_bdim && weight_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output);

    // BNO, BNI, BOI -> N(BO), N(BI), (BO)I
    const auto batch_size = weight.size(*weight_bdim);
    input = reshape_dim_into(*input_bdim, 1, input);
    weight = reshape_dim_into(*weight_bdim, 0, weight);
    const auto result = at::convolution_backward_symint(
        grad_output, input, weight, std::nullopt, stride, padding, dilation,
        transposed, output_padding, batch_size * groups, output_mask);
    // N(BI), (BO)I -> NBI, BOI
    const auto grad_input = output_mask[0] ?
        reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
    const auto grad_weight = output_mask[1] ?
        reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
    return std::make_tuple(
        output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
        output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
        grad_bias);
  }

  Tensor grad_input;
  if (output_mask[0]) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = convolution_backward_input_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }

  Tensor grad_weight;
  if (output_mask[1]) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = convolution_backward_weight_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);

  // Someone's definitely going to find a problem with this batching rule, so
  // I'm leaving the following fallback here in case we need it back:
  // static auto op = c10::Dispatcher::singleton()
  //   .findSchemaOrThrow("aten::convolution_backward", "");
  // auto result = slow_fallback<Tensor,Tensor,Tensor>(op, {
  //   grad_output_, input_, weight_, bias_sizes_opt,
  //   stride, padding, dilation, transposed, output_padding, groups, output_mask
  // });
  // return std::make_tuple(grad_input, std::get<1>(result), grad_bias);
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(convolution, convolution_batch_rule);
  m.impl("_convolution", _convolution_decomp);
  m.impl("convolution_backward", convolution_backward_plumbing);
}

} // namespace at::functorch