#include <ATen/RedispatchFunctions.h>
#include <ATen/TracerMode.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <torch/library.h>
#include <optional>

#include <utility>

using namespace at;
using namespace torch::autograd::generated;
using torch::autograd::as_view;
using torch::autograd::CreationMeta;

namespace torch {

namespace autograd::VariableType {

static std::vector<at::DeprecatedTypeProperties*> allTypesForBackends(
    at::ArrayRef<at::Backend> backends) {
  std::vector<DeprecatedTypeProperties*> res;
  res.reserve(backends.size());
  for (auto p : backends) {
    for (const auto s :
         c10::irange(static_cast<int64_t>(ScalarType::NumOptions))) {
      auto& type = getDeprecatedTypeProperties(
          static_cast<Backend>(p), static_cast<ScalarType>(s));
      res.emplace_back(&type);
    }
  }
  return res;
}

std::vector<at::DeprecatedTypeProperties*> allCPUTypes() {
  return allTypesForBackends({Backend::CPU, Backend::SparseCPU});
}

std::vector<at::DeprecatedTypeProperties*> allCUDATypes() {
  at::globalContext().lazyInitCUDA();
  return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA});
}

std::vector<at::DeprecatedTypeProperties*> allXPUTypes() {
  return allTypesForBackends({Backend::XPU, Backend::SparseXPU});
}

std::vector<at::DeprecatedTypeProperties*> allPrivateUser1Types() {
  at::globalContext().lazyInitPrivateUse1();
  return allTypesForBackends(
      {Backend::PrivateUse1, Backend::SparsePrivateUse1});
}

namespace {
const Variable& checked_cast_variable(
    const Tensor& t,
    const char* name,
    int pos) {
  if (!t.defined()) {
    AT_ERROR(
        "Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
        "for argument #",
        pos,
        " '",
        name,
        "'");
  }
  return t;
}

Variable& checked_cast_variable(Tensor& t, const char* name, int pos) {
  if (!t.defined()) {
    AT_ERROR(
        "Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
        "for argument #",
        pos,
        " '",
        name,
        "'");
  }
  return t;
}
} // namespace

const Tensor& unpack(const Tensor& t, const char* name, int pos) {
  return checked_cast_variable(t, name, pos);
}

Tensor& unpack(Tensor& t, const char* name, int pos) {
  return checked_cast_variable(t, name, pos);
}

Tensor unpack_opt(const Tensor& t, const char* name, int pos) {
  if (!t.defined()) {
    return Tensor();
  }
  return unpack(t, name, pos);
}

std::vector<at::Tensor> unpack(
    const at::ITensorListRef& tl,
    const char* name,
    int pos) {
  std::vector<at::Tensor> ret;
  ret.reserve(tl.size());
  for (const auto& t : tl) {
    ret.push_back(t);
  }
  return ret;
}

namespace {

// Taken from codegened version
Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<Identity> grad_fn;
  if (compute_requires_grad(self)) {
    grad_fn = std::make_shared<Identity>();
    grad_fn->set_next_edges(collect_next_edges(self));
  }

  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::_fw_primal(
        ks & c10::after_autograd_keyset, self_, level);
  })();

  if (grad_fn) {
    set_history(flatten_tensor_args(result), grad_fn);
  }
  if (isFwGradDefined(self)) {
    // Modified from original codegen
    // We explicitly want to ignore the forward grad at the given level
    TORCH_CHECK(level == 0, "Invalid level given to _fw_primal");
    // End modified from original codegen
  }
  return result;
}
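
// A usage sketch (assumes the generated Tensor::_fw_primal binding; not part
// of this file):
//   auto primal_only = dual._fw_primal(/*level=*/0);
//   // primal_only aliases the primal data; the kernel above deliberately
//   // does not attach the tangent to the result.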

// NB: We need a manual variable type kernel so that set_fw_grad properly
// detects that _make_dual is not a forward-differentiable view
//
// This function can be used to create a dual Tensor that holds a tangent to
// compute forward mode gradients. Note that the dual Tensor's primal is a view
// of the given primal and the given tangent is used as-is. This function is
// backward differentiable.
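//
// A usage sketch (assumes the generated at::_make_dual binding and an active
// forward-AD level, e.g. one entered from Python via
// torch.autograd.forward_ad.dual_level):
//   auto primal = at::randn({2, 3});
//   auto tangent = at::randn({2, 3});
//   auto dual = at::_make_dual(primal, tangent, /*level=*/0);
//   // dual is a view of `primal`; dual._fw_grad(/*level=*/0) is `tangent`.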
Tensor _make_dual(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level) {
  TORCH_CHECK(
      !primal._fw_grad(level).defined(),
      "Making a dual Tensor based on a Tensor that "
      "already has a forward gradient at the same level ",
      level,
      " is not supported.");
  auto& primal_ = unpack(primal, "primal", 0);
  auto& tangent_ = unpack(tangent, "tangent", 1);
  std::shared_ptr<ViewBackward0> grad_fn;
  if (compute_requires_grad(primal_)) {
    grad_fn = std::make_shared<ViewBackward0>();
    grad_fn->self_sym_sizes = primal_.sym_sizes().vec();
    grad_fn->set_next_edges(collect_next_edges(primal_));
  }

  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::_make_dual(
        ks & c10::after_autograd_keyset, primal_, tangent_, level);
  })();

  if (grad_fn) {
    set_history(flatten_tensor_args(result), grad_fn);
  }

  TORCH_CHECK(level == 0, "Invalid level given to _make_dual");
  result._set_fw_grad(tangent_, level, /* is_inplace_op */ false);
  return result;
}

// We don't have an outplace copy, so this can't be generated automatically
Tensor& copy_(
    c10::DispatchKeySet ks,
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  // TODO: once copy is exposed in Declarations.yaml we may be able to bind
  // it automatically
  auto& self_ = unpack(self, "self", 0);
  auto& src_ = unpack(src, "src", 1);
  std::shared_ptr<CopyBackwards> grad_fn;
  auto requires_grad = compute_requires_grad(self, src);
  requires_grad &= isDifferentiableType(self.scalar_type());
  check_inplace(self, requires_grad);
  if (requires_grad) {
    grad_fn = std::make_shared<CopyBackwards>();
    grad_fn->set_next_edges(collect_next_edges(self, src));
    grad_fn->src_options = src.options();
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::copy_(
        ks & c10::after_autograd_keyset, self_, src_, non_blocking);
  }
  rebase_history(self, std::move(grad_fn));

  if (isDifferentiableType(self.scalar_type()) &&
      (isFwGradDefined(self) || isFwGradDefined(src))) {
    auto self_fw_grad = generated::details::toNonOptFwGrad(self);
    auto src_fw_grad = generated::details::toNonOptFwGrad(src);
    Tensor new_fw_grad;
    if (self_fw_grad.defined()) {
      if (src_fw_grad.defined()) {
        new_fw_grad = self_fw_grad.copy_(src_fw_grad);
      } else {
        new_fw_grad = self_fw_grad.fill_(0);
      }
    } else {
      if (!self.is_same_size(src_fw_grad)) {
        new_fw_grad = src_fw_grad.broadcast_to(self.sizes());
      } else {
        new_fw_grad = src_fw_grad.clone();
      }
    }
    self._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true);
  }

  return self;
}
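
// A behavior sketch (names are illustrative, not part of this file): after
//   self.copy_(src);
// backward gradients flow to `src` through CopyBackwards, and if either
// tensor carries a forward grad at level 0 (and `self` has a differentiable
// dtype), `self`'s tangent is updated in-place as above: copied, zeroed,
// broadcast, or cloned from `src`'s tangent.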

const Tensor& resize_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    SymIntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  auto& self_ = unpack(self, "self", 0);
  if (self.requires_grad()) {
    AT_ERROR("cannot resize variables that require grad");
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::resize__symint(
        ks & c10::after_autograd_keyset, self_, size, optional_memory_format);
  }

  if (self._fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that have a forward grad");
  }

  return self;
}
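
// A quick sketch of the guard above (illustrative only):
//   auto t = at::ones({2, 2});
//   t.set_requires_grad(true);
//   t.resize_({4}); // expected to throw "cannot resize variables that require grad"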

const Tensor& resize_as_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    const Tensor& the_template,
    std::optional<MemoryFormat> optional_memory_format) {
  auto& self_ = unpack(self, "self", 0);
  auto& the_template_ = unpack(the_template, "the_template", 1);
  if (self.requires_grad()) {
    AT_ERROR("cannot resize variables that require grad");
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::resize_as_(
        ks & c10::after_autograd_keyset,
        self_,
        the_template_,
        optional_memory_format);
  }

  // Handle fw grad
  if (self._fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that have a forward grad");
  }

  return self;
}

Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
  auto& self_ = unpack(self, "self", 0);
  RECORD_FUNCTION("detach", std::vector<c10::IValue>({self}));
  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::detach(ks & c10::after_autograd_keyset, self_);
  })();
  namedinference::propagate_names(result, self);

  // Detach the forward grads by not setting anything on the result

  return result;
}

Tensor& detach_(c10::DispatchKeySet ks, Tensor& self) {
  RECORD_FUNCTION("detach_", std::vector<c10::IValue>({self}));
  if (self.is_view()) {
    // See NOTE [ View + Inplace detection ]
    AT_ERROR(
        "Can't detach views in-place. Use detach() instead. "
        "If you are using DistributedDataParallel (DDP) for training, "
        "and gradient_as_bucket_view is set as True, gradients are "
        "views of DDP buckets, and hence detach_() cannot be called "
        "on these gradients. To fix this error, please refer to the "
        "Optimizer.zero_grad() function in torch/optim/optimizer.py "
        "as the solution.");
  }
  // I think the choice here is conservative. In principle, doing
  // an in-place detach should give us the ability to just clear
  // the autograd meta. But this function ONLY resets requires_grad,
  // grad_fn and output_nr; there's other metadata like debug name
  // and hooks which aren't cleared. Is this function supposed to
  // clear those too? I'm not too sure, so I'm leaving it be for now.
  auto autograd_meta = impl::materialize_autograd_meta(self);
  autograd_meta->set_requires_grad(false, self.unsafeGetTensorImpl());
  autograd_meta->grad_fn_.reset();
  autograd_meta->output_nr_ = 0;
  autograd_meta->fw_grad_.reset();

  return self;
}
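
// A behavior sketch (illustrative): after
//   x.detach_();
// x.requires_grad() is false, x.grad_fn() is null, and any forward grad on x
// is dropped, while hooks and other autograd metadata are left untouched, as
// noted in the comment above.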

// Ops in the following registration list are registered as
//   (1) CompositeImplicitAutograd kernels
//   (2) Autograd kernels
//   (3) CompositeExplicitAutograd kernels and additionally Autograd kernels
// The reason for (3) is that ops that also use dispatch (e.g. register
// CPU/CUDA/QuantizedCPU kernels) will skip picking up CompositeImplicitAutograd
// kernels for Autograd, so we register them to both CompositeExplicitAutograd
// and Autograd instead. See
// https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword
// for more details.
// Invariant:
// - Ops registered to CompositeImplicitAutograd or CompositeExplicitAutograd
//   below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py,
//   and they have manual_kernel_registration=True in native_functions.yaml.
// - Ops registered to DispatchKey::Autograd below must be included in
//   `MANUAL_AUTOGRAD` in tools/autograd/gen_variable_type.py

TORCH_LIBRARY_IMPL(aten, Autograd, m) {
  m.impl(
      "resize_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_)));
  m.impl(
      "resize_as_",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::resize_as_)));
  m.impl(
      "detach",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach)));
  m.impl(
      "detach_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach_)));
  m.impl(
      "copy_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::copy_)));
  m.impl(
      "_fw_primal",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal)));
  m.impl(
      "_make_dual",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::_make_dual)));
}

} // namespace
} // namespace autograd::VariableType

namespace ADInplaceOrView {
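// CREATION_META_DEFINITION selects the CreationMeta recorded for views created
// by the kernels below: INFERENCE_MODE when inference mode is enabled,
// DEFAULT when grad mode is on, and NO_GRAD_MODE otherwise.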
#define CREATION_META_DEFINITION                            \
  InferenceMode::is_enabled()                               \
      ? CreationMeta::INFERENCE_MODE                        \
      : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT \
                                    : CreationMeta::NO_GRAD_MODE)

static Tensor& copy_(
    c10::DispatchKeySet ks,
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::redispatch::copy_(
        ks & c10::after_ADInplaceOrView_keyset, self, src, non_blocking);
  }
  torch::autograd::increment_version(self);
  return self;
}

static const Tensor& resize_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    SymIntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  // Hold sizes to verify if we actually resize `self`.
  // Explicitly copy data, since resizing can move original data
  // and make references invalid.
  auto org_size = self.sym_sizes().vec();
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::redispatch::resize__symint(
        ks & c10::after_ADInplaceOrView_keyset,
        self,
        size,
        optional_memory_format);
  }
  // If `self` was resized, increment the version.
  if (org_size != size) {
    torch::autograd::increment_version(self);
  }
  return self;
}
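
// Note on the version bump above (a sketch of the intended effect): bumping
// the version counter is what lets autograd detect that a tensor saved for
// backward was later resized in-place, so a subsequent backward() can raise
// the usual "modified by an inplace operation" error instead of silently
// using stale data.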

static const Tensor& resize_as_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    const Tensor& the_template,
    std::optional<MemoryFormat> optional_memory_format) {
  // Hold sizes to verify if we actually resize `self`.
  // Explicitly copy data, since resizing can move original data
  // and make references invalid.
  auto org_size = self.sym_sizes().vec();
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::redispatch::resize_as_(
        ks & c10::after_ADInplaceOrView_keyset,
        self,
        the_template,
        optional_memory_format);
  }

  // If `self` was resized, increment the version.
  if (org_size != the_template.sym_sizes()) {
    torch::autograd::increment_version(self);
  }
  return self;
}

static Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
  auto out = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::_ops::detach::redispatch(
        ks & c10::after_ADInplaceOrView_keyset, self);
  })();
  // NB: we can't make detach() a normal view operator because the codegen
  // generates allow_tensor_metadata_change = True for them. In the future we
  // should have an option for this in the codegen.
  auto result = as_view(
      /* base */ self,
      /* output */ out,
      /* is_bw_differentiable */ false,
      /* is_fw_differentiable */ false,
      /* view_func */ nullptr,
      /* rev_view_func */ nullptr,
      /* creation_meta */ CreationMeta::DEFAULT,
      /*allow_tensor_metadata_change=*/false);

  return result;
}
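
// A behavior sketch of allow_tensor_metadata_change=false (illustrative):
// metadata mutations on the detached alias, e.g.
//   self.detach().resize_({1});
// are expected to error rather than silently alter the base tensor's
// metadata.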

static Tensor _fw_primal(
    c10::DispatchKeySet ks,
    const Tensor& self,
    int64_t level) {
  auto tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::alias(self);
  })();
  std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
  std::function<at::Tensor(const at::Tensor&)> rev_func = nullptr;
  if (!self.unsafeGetTensorImpl()->support_as_strided()) {
    func = std::make_unique<ViewViewFunc>(self.sym_sizes());
    rev_func = [=](const at::Tensor& input_view) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Reverse view_func for _fw_primal() is not currently supported");
      return Tensor();
    };
  }
  auto result = as_view(
      /* base */ self,
      /* output */ tmp,
      /* is_bw_differentiable */ true,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* rev_view_func */ std::move(rev_func),
      /* creation_meta */ CREATION_META_DEFINITION);

  return result;
}

// NB: This does not redispatch any further
static Tensor _make_dual(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level) {
  auto tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::alias(primal);
  })();
  std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
  std::function<at::Tensor(const at::Tensor&)> rev_func = nullptr;
  if (!primal.unsafeGetTensorImpl()->support_as_strided()) {
    func = std::make_unique<ViewViewFunc>(primal.sym_sizes());
    rev_func = [=](const at::Tensor& input_view) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Reverse view_func for _make_dual() is not currently supported");
      return Tensor();
    };
  }
  auto result = as_view(
      /* base */ primal,
      /* output */ tmp,
      /* is_bw_differentiable */ true,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* rev_view_func */ std::move(rev_func),
      /* creation_meta */ CREATION_META_DEFINITION);

  return result;
}

namespace {
TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) {
  m.impl(
      "copy_",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::copy_)));
  m.impl(
      "detach",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::detach)));
  m.impl(
      "resize_",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::resize_)));
  m.impl(
      "resize_as_",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::resize_as_)));
  m.impl(
      "_fw_primal",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_fw_primal)));
  m.impl(
      "_make_dual",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_make_dual)));
}
} // namespace
} // namespace ADInplaceOrView
} // namespace torch